README: Make default dir names compatible with README
[gostyle.git] / gnet / gnet_train.c
blobd4a801f48bb4d489b71f1b8aeccc890e8ed2e18c
#include <errno.h>
#include <float.h>
#include <getopt.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "floatfann.h"
7 /* Prototypes */
8 void print_help(void);
10 /* There we go! */
11 int main(int argc, char ** argv) {
12 /* Default network params */
13 unsigned int num_layers = 3;
14 unsigned int num_neurons_hidden = 30;
15 unsigned int max_epochs = 5000;
16 float desired_error = 0.0001;
17 float learning_momentum = 0.2;
19 char * net_output_name = "go_net.net";
20 char * train_data_name=0;
22 /* Parse input */
24 static struct option long_options[] = {
25 { "layers", required_argument, NULL, 'l' },
26 { "neurons-hidden", required_argument, NULL, 'n' },
27 { "max-epochs", required_argument, NULL, 'p' },
28 { "desired-error", required_argument, NULL, 'e' },
29 { "learning-momentum", required_argument, NULL, 'm' },
30 { "net-output-file", required_argument, NULL, 'o' },
31 { "help", no_argument, NULL, 'h' },
32 { NULL, no_argument, NULL, 0 }
34 int c;
35 while( (c = getopt_long(argc, argv, "hl:n:p:e:m:o:", long_options, &optind)) != -1 ){
36 switch (c){
37 case 'l': num_layers = atoi(optarg); break;
38 case 'n': num_neurons_hidden = atoi(optarg); break;
39 case 'p': max_epochs = atoi(optarg); break;
40 case 'e': desired_error = atof(optarg); break;
41 case 'm': learning_momentum = atof(optarg); break;
42 case 'o': net_output_name = optarg; break;
43 case 'h': print_help(); exit(1); break;
44 case '?': break;
45 default: exit(1);
48 while (optind < argc)
49 train_data_name= argv[optind++];
52 if( ! train_data_name ){ fprintf(stderr, "No training data file specified.\n"); exit(1); }
53 if( num_layers <= 0 ){ fprintf(stderr, "Number of layers must be positive.\n"); exit(1); }
54 if( num_neurons_hidden <= 0 ){ fprintf(stderr, "Number of neurons in the hidden layer must be positive.\n"); exit(1); }
55 if( max_epochs <= 0 ){ fprintf(stderr, "Max number of epochs must be positive.\n"); exit(1); }
56 if( desired_error <= 0 ){ fprintf(stderr, "Desired error must be positive.\n"); exit(1); }
57 if( learning_momentum <= 0 ){ fprintf(stderr, "Learning momentum be positive.\n"); exit(1); }
58 //#ifdef DEBUG
59 #if 1
60 printf("Layers: %u\n", num_layers);
61 printf("Neurons hidden: %u\n", num_neurons_hidden);
62 printf("Max epochs: %u\n", max_epochs);
63 printf("Desired error: %f\n", desired_error);
64 printf("Learning momentum: %f\n", learning_momentum);
65 printf("Net output file: %s\n", net_output_name);
66 printf("Train data file: %s\n", train_data_name);
67 printf("\n");
68 #endif
70 /* Create the net */
71 struct fann *ann=0;
72 struct fann_train_data *train_data=0;
74 printf("Creating network.\n");
75 train_data = fann_read_train_from_file(train_data_name);
76 /* Die if error */
77 if( ! train_data){
78 fprintf(stderr, "Error reading file '%s'.\n", train_data_name);
79 exit(1);
82 switch (num_layers){
83 case 2:
84 ann = fann_create_standard(num_layers, train_data->num_input, train_data->num_output);
85 break;
86 case 3:
87 ann = fann_create_standard(num_layers, train_data->num_input, num_neurons_hidden, train_data->num_output);
88 break;
89 case 4:
90 ann = fann_create_standard(num_layers, train_data->num_input, num_neurons_hidden, num_neurons_hidden, train_data->num_output);
91 break;
92 case 5:
93 ann = fann_create_standard(num_layers, train_data->num_input, num_neurons_hidden, num_neurons_hidden, num_neurons_hidden, train_data->num_output);
94 break;
95 default:
96 fprintf(stderr, "Wrong number of layers..");
97 exit(1);
99 fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC);
100 fann_set_activation_function_output(ann, FANN_SIGMOID_SYMMETRIC);
102 /* Train the net */
103 printf("Training network.\n");
104 //fann_set_training_algorithm(ann, FANN_TRAIN_INCREMENTAL);
105 //fann_set_training_algorithm(ann, FANN_TRAIN_QUICKPROP);
107 fann_set_learning_momentum(ann, learning_momentum);
108 fann_train_on_data(ann, train_data, max_epochs, 50, desired_error);
110 // fann_set_activation_function_hidden(ann, FANN_THRESHOLD_SYMMETRIC);
111 // fann_set_activation_function_output(ann, FANN_THRESHOLD_SYMMETRIC);
115 /* Save the net */
116 printf("Saving network.\n");
117 fann_save(ann, net_output_name);
120 /* Clean the net */
121 fann_destroy_train(train_data);
122 fann_destroy(ann);
124 return 0;
127 void print_help(void){
128 printf("Usage: gnet_train [OPTIONS] TRAIN_DATA_FILENAME\n\
130 Trains a neural network from the TRAIN_DATA_FILENAME.\n\
132 TRAIN_DATA_FILENAME format:\n\
133 number_of_pairs length_of_input_vector length_of_output_vector\n\
134 input_vector\n\
135 output_vector\n\
136 another_input_vector\n\
137 another_output_vector\n\
138 ...\n\
139 EOF\n\
141 OPTIONS\n\
142 -l int_number\n\
143 --layers=int_number\n\
144 Number of network layers\n\
145 -n int_number\n\
146 --neurons-hidden=int_number\n\
147 Number of neurons in hidden layers.\n\
148 -p int_number\n\
149 --max-epochs=int_number\n\
150 Maximal number of epochs.\n\
151 -e float_number\n\
152 --desired-error=float_number\n\
153 Desired error when to stop training.\n\
154 -m float_number\n\
155 --learning-momentum=float_number\n\
156 Learning momentum\n\
157 -o filename\n\
158 --net-output-file=filename\n\
159 Where to save the net.\n\
161 EXAMPLE\n\
162 gnet_train -l 3 -n 666 -p 1000 -e 0.00666 -o net.net dataset.data\n");