Merge branch 'master' of ssh://repo.or.cz/srv/git/gostyle
[gostyle.git] / gnet / gnet_train.c
blob6773e21e85fa0cbb8f58bf5911a9e8e5fa60dbba
#include <errno.h>
#include <getopt.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "floatfann.h"
/* Output-layer activation choices selectable with -a/--activation-function-output.
 * main() maps LINEAR to FANN_LINEAR and SIGMOID to FANN_SIGMOID_SYMMETRIC. */
enum {
    LINEAR = 0,
    SIGMOID = 1
};

/* Prototypes */
void print_help(void);
12 /* There we go! */
13 int main(int argc, char ** argv) {
14 /* Default network params */
15 unsigned int num_layers = 3;
16 unsigned int num_neurons_hidden = 30;
17 unsigned int max_epochs = 5000;
18 float desired_error = 0.0001;
19 float learning_momentum = 0.2;
21 char * activation_function_output = "sigmoid";
22 int actfc_output = SIGMOID;
24 char * net_output_name = "go_net.net";
25 char * train_data_name=0;
27 /* Parse input */
29 static struct option long_options[] = {
30 { "layers", required_argument, NULL, 'l' },
31 { "neurons-hidden", required_argument, NULL, 'n' },
32 { "max-epochs", required_argument, NULL, 'p' },
33 { "desired-error", required_argument, NULL, 'e' },
34 { "learning-momentum", required_argument, NULL, 'm' },
35 { "net-output-file", required_argument, NULL, 'o' },
36 { "help", no_argument, NULL, 'h' },
37 { "activation-function-output", no_argument, NULL, 'a' },
38 { NULL, no_argument, NULL, 0 }
40 int c;
41 while( (c = getopt_long(argc, argv, "hl:n:p:e:m:o:a:", long_options, &optind)) != -1 ){
42 switch (c){
43 case 'l': num_layers = atoi(optarg); break;
44 case 'n': num_neurons_hidden = atoi(optarg); break;
45 case 'p': max_epochs = atoi(optarg); break;
46 case 'e': desired_error = atof(optarg); break;
47 case 'm': learning_momentum = atof(optarg); break;
48 case 'o': net_output_name = optarg; break;
49 case 'h': print_help(); exit(1); break;
50 case 'a': activation_function_output = optarg; break;
51 case '?': break;
52 default: exit(1);
55 while (optind < argc)
56 train_data_name= argv[optind++];
59 if( ! train_data_name ){ fprintf(stderr, "No training data file specified.\n"); exit(1); }
60 if( num_layers <= 0 ){ fprintf(stderr, "Number of layers must be positive.\n"); exit(1); }
61 if( num_neurons_hidden <= 0 ){ fprintf(stderr, "Number of neurons in the hidden layer must be positive.\n"); exit(1); }
62 if( max_epochs <= 0 ){ fprintf(stderr, "Max number of epochs must be positive.\n"); exit(1); }
63 if( desired_error <= 0 ){ fprintf(stderr, "Desired error must be positive.\n"); exit(1); }
64 if( learning_momentum <= 0 ){ fprintf(stderr, "Learning momentum be positive.\n"); exit(1); }
65 if( !strcmp(activation_function_output, "linear") )
66 actfc_output = LINEAR;
67 else if( !strcmp(activation_function_output, "sigmoid") )
68 actfc_output = SIGMOID;
69 else{ fprintf(stderr, "Activation function must be either 'linear' or 'sigmoid' (default).\n"); exit(1);}
71 //#ifdef DEBUG
72 #if 1
73 printf("Layers: %u\n", num_layers);
74 printf("Neurons hidden: %u\n", num_neurons_hidden);
75 printf("Max epochs: %u\n", max_epochs);
76 printf("Desired error: %f\n", desired_error);
77 printf("Learning momentum: %f\n", learning_momentum);
78 printf("Net output file: %s\n", net_output_name);
79 printf("Train data file: %s\n", train_data_name);
80 printf("Output layer activation function: %s\n", activation_function_output);
82 printf("\n");
83 #endif
85 /* Create the net */
86 struct fann *ann=0;
87 struct fann_train_data *train_data=0;
89 //printf("Loading training data file.\n");
90 train_data = fann_read_train_from_file(train_data_name);
91 /* Die if error */
92 if( ! train_data){
93 fprintf(stderr, "Error reading file '%s'.\n", train_data_name);
94 exit(1);
96 if( num_layers == 1)
97 fprintf(stderr, "Warning, network has only one layer.\n");
98 if( num_layers >= 10)
99 fprintf(stderr, "Warning, network has more than 10 layers.\n");
101 unsigned int * layers = ( unsigned int * ) malloc( num_layers * sizeof(unsigned int) );
104 layers[0] = train_data->num_input;
105 unsigned i;
106 for( i = 1 ; i < num_layers - 1 ; i++)
107 layers[i] = num_neurons_hidden;
108 layers[num_layers - 1] = train_data->num_output;
111 printf( "Network architecture:\n ->");
112 for( i = 0 ; i < num_layers ; i++)
113 printf( "%d-", layers[i]);
114 printf( ">\n\n");
117 //printf("Creating network.\n");
118 ann = fann_create_standard_array(num_layers, layers);
120 fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC);
122 switch(actfc_output){
123 case LINEAR:
124 fann_set_activation_function_output(ann, FANN_LINEAR);
125 break;
126 case SIGMOID:
127 fann_set_activation_function_output(ann, FANN_SIGMOID_SYMMETRIC);
128 break;
129 default:
130 fprintf(stderr, "Activation function must be either 'linear' or 'sigmoid' (default).\n");
131 exit(1);
134 /* Train the net */
135 printf("Training network:\n");
136 //fann_set_training_algorithm(ann, FANN_TRAIN_INCREMENTAL);
137 //fann_set_training_algorithm(ann, FANN_TRAIN_QUICKPROP);
139 fann_set_learning_momentum(ann, learning_momentum);
140 fann_train_on_data(ann, train_data, max_epochs, 50, desired_error);
142 // fann_set_activation_function_hidden(ann, FANN_THRESHOLD_SYMMETRIC);
143 // fann_set_activation_function_output(ann, FANN_THRESHOLD_SYMMETRIC);
147 /* Save the net */
148 fann_save(ann, net_output_name);
149 printf("\nNetwork saved.\n");
152 /* Clean the net */
153 fann_destroy_train(train_data);
154 fann_destroy(ann);
156 return 0;
/* Print usage information, the expected training-data file format, the
 * available options with their defaults, and an example invocation to stdout.
 * Adjacent string literals are used (instead of backslash-newline inside one
 * literal) so the source indentation never leaks into the printed text. */
void print_help(void){
    fputs(
        "Usage: gnet_train [OPTIONS] TRAIN_DATA_FILENAME\n"
        "\n"
        "Trains a neural network from the TRAIN_DATA_FILENAME.\n"
        "\n"
        "TRAIN_DATA_FILENAME format:\n"
        "    number_of_pairs length_of_input_vector length_of_output_vector\n"
        "    input_vector\n"
        "    output_vector\n"
        "    another_input_vector\n"
        "    another_output_vector\n"
        "    ...\n"
        "    EOF\n"
        "\n"
        "OPTIONS\n"
        "    -h\n"
        "    --help\n"
        "        Print this help and exit.\n"
        "    -l int_number\n"
        "    --layers=int_number\n"
        "        Number of network layers, including the input and output layers (default 3).\n"
        "    -n int_number\n"
        "    --neurons-hidden=int_number\n"
        "        Number of neurons in each hidden layer (default 30).\n"
        "    -p int_number\n"
        "    --max-epochs=int_number\n"
        "        Maximal number of training epochs (default 5000).\n"
        "    -e float_number\n"
        "    --desired-error=float_number\n"
        "        Desired error at which to stop training (default 0.0001).\n"
        "    -m float_number\n"
        "    --learning-momentum=float_number\n"
        "        Learning momentum (default 0.2).\n"
        "    -o filename\n"
        "    --net-output-file=filename\n"
        "        Where to save the trained network (default go_net.net).\n"
        "    -a FUNCTION\n"
        "    --activation-function-output=FUNCTION\n"
        "        Sets activation function for the output layer.\n"
        "        Possibilities are 'sigmoid' (default) or 'linear' (without apostrophes).\n"
        "\n"
        "EXAMPLE\n"
        "    gnet_train -l 3 -n 666 -p 1000 -e 0.00666 -o net.net dataset.data\n",
        stdout);
}