#include <iostream.h>
#include <iomanip.h>
#include <stdlib.h>
#include <math.h>
#include <strings.h>
#include <time.h>
#include "util.h"
#include "data_map.h"
#include "learning_method.h"
#include "command_line.h"

// Per-run training state shared by initialize_train_variables,
// generate_train_map and main.

// Example order used to build train_map. It stays NULL until main reads it
// from a file or initialize_train_variables generates it.
static int *data_set_ordering = NULL;

// The view of current_data_set that is handed to the learning method.
static DataMap train_map;

// Column count of train_set_output_matrix, taken from get_output_vector_size().
static int output_vector_size;

// The learned model's outputs, one row per training example.
static double **train_set_output_matrix = NULL;

// One square confusion matrix per output class.
static double ***train_set_confusion_matrix = NULL;

// Target filenames and on/off flags for the save sub-options
// ("ordering", "dataset", "model").
static char *save_ordering_fname = NULL;
static bool  save_ordering       = false;
static char *save_data_set_fname = NULL;
static bool  save_data_set       = false;
static char *save_model_fname    = NULL;
static bool  save_model          = false;

// Prints the help entry for the "ordering" save sub-option.
// The description refers to the argument as FNAME to match the synopsis line.
void explain_save_ordering_item() {
  cout << "      ordering FNAME\n"
          "          saves the random pattern ordering used in this experiment in FNAME\n";
}

// Prints the help entry for the "dataset" save sub-option.
// The description refers to the argument as FNAME to match the synopsis line.
void explain_save_data_set_item() {
  cout << "      dataset FNAME\n"
          "          saves the data set used in this experiment in FNAME\n";
}

// Prints the help entry for the "model" save sub-option.
// main exits with an error when no model filename was given, hence "REQUIRED".
// The description refers to the argument as FNAME to match the synopsis line.
void explain_save_model_item() {
  cout << "      model FNAME\n"
          "          saves the model learned in this experiment in FNAME\n"
          "          this option is REQUIRED\n";
}

// Sub-options attached to get_save_values_item. Each item is given a keyword
// ("ordering", "dataset", "model"), the addresses of the filename pointer and
// bool flag it controls (presumably filled in when the keyword is parsed), and
// the callback that prints its help text.
// NOTE(review): these objects are constructed during static initialization in
// declaration order; if they register themselves with their parent item, that
// order may also determine help-listing order - confirm before reordering.
static SetFileNameItem save_ordering_item(&get_save_values_item,&save_ordering_fname,&save_ordering,"ordering",explain_save_ordering_item);
static SetFileNameItem save_data_set_item(&get_save_values_item,&save_data_set_fname,&save_data_set,"dataset",explain_save_data_set_item);
static SetFileNameItem save_model_item(&get_save_values_item,&save_model_fname,&save_model,"model",explain_save_model_item);

// Filename and flag bound to the "ordering" sub-option of get_use_values_item.
// When use_ordering is set, main reads data_set_ordering from
// use_ordering_fname instead of generating a new ordering.
static char *use_ordering_fname = NULL;
static bool  use_ordering       = false;

// Prints the help entry for the "ordering" use sub-option.
// The description refers to the argument as FNAME to match the synopsis line.
void explain_use_ordering_item() {
  cout << "      ordering FNAME\n"
          "          uses the random pattern ordering stored in FNAME in this experiment\n";
}

// "ordering" sub-option attached to get_use_values_item. It is given the
// addresses of use_ordering_fname and use_ordering (presumably set when the
// keyword is parsed); main calls do_read_ordering with them when use_ordering
// is true.
static SetFileNameItem use_ordering_item(&get_use_values_item,&use_ordering_fname,&use_ordering,"ordering",explain_use_ordering_item);

// Report switches, both off by default:
//   report_train_set_accuracy - print the confusion matrix on the training data
//   report_all_predictions    - print predicted vs. actual output vectors
static bool report_train_set_accuracy = false;
static bool report_all_predictions    = false;

// Help entry for the "train" report keyword.
void explain_set_train_report_item() {
  cout << "      train\n"
          "          causes the experiment to test and report the accuracy of\n"
          "          the model on the TRAINing data\n";
}

// Help entry for the "notrain" report keyword (the default setting).
void explain_set_notrain_report_item() {
  cout << "      notrain\n"
          "          disables the -report train flag (DEFAULT)\n";
}

// Help entry for the "predictions" report keyword.
void explain_set_predictions_report_item() {
  cout << "      predictions\n"
          "          causes the experiment to print the resulting predicted and actual\n"
          "          output vectors\n";
}

// Help entry for the "nopredictions" report keyword (the default setting).
void explain_set_nopredictions_report_item() {
  cout << "      nopredictions\n"
          "          disables the -report predictions flag (DEFAULT)\n";
}

// Keywords attached to get_report_values_item. Each SetFlagItem is given the
// address of a bool and a value (0 or 1), presumably stored into that bool
// when the keyword is parsed:
//   nopredictions / predictions -> report_all_predictions
//   notrain       / train       -> report_train_set_accuracy
// NOTE(review): construction happens during static initialization in
// declaration order; keep the order unless it is confirmed not to matter.
static SetFlagItem set_nopredictions_report_item(&get_report_values_item,&report_all_predictions,0,"nopredictions",explain_set_nopredictions_report_item);
static SetFlagItem set_predictions_report_item(&get_report_values_item,&report_all_predictions,1,"predictions",explain_set_predictions_report_item);
static SetFlagItem set_notrain_report_item(&get_report_values_item,&report_train_set_accuracy,0,"notrain",explain_set_notrain_report_item);
static SetFlagItem set_train_report_item(&get_report_values_item,&report_train_set_accuracy,1,"train",explain_set_train_report_item);

// Allocates the buffers for this run and prepares train_map.
//   train_set_output_matrix    num_examples x output_vector_size
//   train_set_confusion_matrix one square matrix per output class, whose side
//                              is that class's num_class_values
// data_set_ordering is generated only if main has not already loaded one:
// non-random when train_set_inorder is set, random otherwise.
void initialize_train_variables() {
  output_vector_size = get_output_vector_size(current_data_set);
  allocate_2d_double_array(train_set_output_matrix,
                           current_data_set->num_examples,
                           output_vector_size);

  train_set_confusion_matrix = new double**[current_data_set->num_output_classes];
  for (int cls = 0; cls < current_data_set->num_output_classes; ++cls) {
    allocate_2d_double_array(train_set_confusion_matrix[cls],
                             current_data_set->class_descriptors[cls].num_class_values,
                             current_data_set->class_descriptors[cls].num_class_values);
  }

  if (data_set_ordering == NULL) {
    if (!train_set_inorder) {
      generate_new_random_ordering(current_data_set->num_examples, data_set_ordering);
    } else {
      generate_non_random_ordering(current_data_set->num_examples, data_set_ordering);
    }
  }

  train_map.initialize(current_data_set, current_data_set->num_examples);
}

// Rebuilds train_map from current_data_set in the order given by
// data_set_ordering. Slot i points at example data_set_ordering[i] and records
// that index as its original position. Every output class in every slot gets
// the same weight, 1 / num_examples.
void generate_train_map() {
  const double uniform_weight = 1.0 / current_data_set->num_examples;

  for (int slot = 0; slot < current_data_set->num_examples; ++slot) {
    const int source = data_set_ordering[slot];
    train_map.examples[slot] = &current_data_set->examples[source];
    train_map.original_index[slot] = source;

    for (int cls = 0; cls < current_data_set->num_output_classes; ++cls) {
      train_map.weight[slot][cls] = uniform_weight;
    }
  }
}

int main (int argc, char *argv[]) {
  cout.setf(ios::fixed, ios::floatfield);
  cout.setf(ios::showpoint);

  process_top_level_commands(argc,argv);
  
  if (!current_data_set) {
    cerr << "ERROR!  Must specify a data_set" << endl;
    exit(-1);
  }

  if (!current_learning_method) {
    cerr << "ERROR!  Must specify a learning method" << endl;
    exit(-1);
  }

  if (!save_model) {
    cerr << "ERROR!  Must specify a filename for the resulting model" << endl;
    exit(-1);
  }

  if (!random_seed_set) srandom(time(0));

  if (save_data_set) do_save_data_set(save_data_set_fname,current_data_set);

  if (use_ordering) do_read_ordering(use_ordering_fname,data_set_ordering,current_data_set->num_examples);

  initialize_train_variables();

  if (save_ordering) do_save_ordering(save_ordering_fname,data_set_ordering,current_data_set->num_examples);

  generate_train_map();

  current_learning_method->learn(&train_map);

  do_save_model(save_model_fname,current_learning_method);

  if (report_train_set_accuracy) {
    generate_non_random_ordering(train_map.num_examples,data_set_ordering);
    generate_train_map();
    current_learning_method->classify(&train_map,train_set_output_matrix,output_vector_size);
    accumulate_confusion_matrix(&train_map,train_set_output_matrix,train_set_confusion_matrix);
    cout << "\n\nConfusion matrix for training set:\n";
    show_confusion_matrix(current_data_set,train_set_confusion_matrix);
  }

  if (report_all_predictions)
    show_predicted_and_actual_vectors(current_data_set,train_set_output_matrix,output_vector_size);

  return 0;
}
