Merge branch 'master' of ssh://repo.or.cz/srv/git/gostyle
[gostyle.git] / gnet / gnet_run.c
blobae462b813565227053f97b9eee4b1270a990547e
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>

#include "floatfann.h"
7 /* Prototypes */
8 void print_help(void);
10 /* There we go! */
11 int main(int argc, char ** argv) {
12 /* Default network params */
13 char * net_input_name=0;
14 char * test_data_filename=0;
15 char * input_data_filename=0;
16 FILE * input_file = stdin;
18 /* Parse input */
20 static struct option long_options[] = {
21 { "test-data-file", required_argument, NULL, 't' },
22 { "input-data-file", required_argument, NULL, 'i' },
23 { "help", no_argument, NULL, 'h' },
24 { NULL, no_argument, NULL, 0 }
26 int c;
27 while( (c = getopt_long(argc, argv, "ht:i:", long_options, &optind)) != -1 ){
28 switch (c){
29 case 't': test_data_filename = optarg; break;
30 case 'i': input_data_filename = optarg; break;
31 case 'h': print_help(); exit(1); break;
32 case '?': break;
33 default: exit(1);
36 while (optind < argc)
37 net_input_name = argv[optind++];
40 if( ! net_input_name ){ fprintf(stderr, "No network file specified, see --help.\n"); exit(1); }
41 //#ifdef DEBUG
42 #if 1
43 fprintf(stderr,"Test input file: %s\n", test_data_filename);
44 fprintf(stderr,"Input file: %s\n", input_data_filename);
45 fprintf(stderr,"Net input file: %s\n", net_input_name);
46 fprintf(stderr,"\n");
47 #endif
48 /* Load Net */
49 struct fann *ann = fann_create_from_file(net_input_name);
50 if( ! ann ){ fprintf(stderr, "Error reading network file '%s'.\n", net_input_name); exit(1); }
52 /* Test the net if applicable */
53 if( test_data_filename ){
54 struct fann_train_data *test_data;
56 fprintf(stderr,"Testing network.\n");
57 test_data = fann_read_train_from_file(test_data_filename);
58 if( ! test_data){
59 fprintf(stderr, "Error reading file '%s'.\n", test_data_filename);
60 }else{
61 fann_reset_MSE(ann);
62 unsigned i;
63 for(i = 0; i < fann_length_train_data(test_data); i++) {
64 fann_test(ann, test_data->input[i], test_data->output[i]);
66 fprintf(stderr,"MSE error on test data: %f\n\n", fann_get_MSE(ann));
67 fann_destroy_train(test_data);
69 if( ! input_data_filename ){
70 fann_destroy(ann);
71 return 0;
75 if( input_data_filename ){
76 // if I should not read from the stdin
77 if( strcmp(input_data_filename,"-") ){
78 input_file = fopen(input_data_filename, "r");
79 if( ! input_file ){ fprintf(stderr, "Error reading input file '%s'.\n", input_data_filename); exit(1); }
83 unsigned int input_vector_len = fann_get_num_input(ann);
84 unsigned int output_vector_len = fann_get_num_output(ann);
86 fann_type * input_vector = (fann_type *) malloc( input_vector_len * sizeof(fann_type));
87 fann_type * output_vector;
89 fprintf(stderr, "Reading input:\n");
90 while(1){
91 /* Read the input vector */
92 unsigned i;
93 for(i=0; i<input_vector_len; i++){
94 if(fscanf(input_file, FANNSCANF, &input_vector[i]) != 1){
95 goto end;
98 /* Process the input vector */
100 // TODO possible leak?? Should I free it after use? Inspect fannsource.
101 output_vector = fann_run( ann, input_vector);
103 /* Write the output */
104 for(i=0; i<output_vector_len; i++){
105 printf(FANNPRINTF " ", output_vector[i]);
107 printf("\n");
110 end:
111 fclose(input_file);
112 fann_destroy(ann);
114 return 0;
/*
 * Usage text printed by print_help(). It is kept as a separate array, built
 * from adjacent string literals instead of backslash-newline continuations,
 * so the text is easy to maintain and check.
 */
static const char gnet_run_help_text[] =
    "Usage: gnet_run [OPTIONS] NETWORK_FILENAME\n"
    "\n"
    "Runs a neural network loaded from NETWORK_FILENAME (as saved by gnet_train).\n"
    "\n"
    "OPTIONS\n"
    "  -t filename\n"
    "  --test-data-file=filename\n"
    "      Runs the network against test data and outputs MSE.\n"
    "  -i filename\n"
    "  --input-data-file=filename\n"
    "      Input file to read the data from. If filename == '-' then reads from the stdin.\n"
    "      The file (or stdin) must have a precise FORMAT, and the vector lengths must agree with\n"
    "      those in NETWORK_FILENAME.\n"
    "      Without -i, input is read from stdin unless -t is given.\n"
    "  -h\n"
    "  --help\n"
    "      Prints this help.\n"
    "\n"
    "Input file FORMAT (numbers may be separated by any whitespace):\n"
    "  first_number_from_the_first_vector second third ...\n"
    "  first_number_from_the_second_vector second third ...\n"
    "  ...\n"
    "  EOF\n";

/* Prints the usage text to stdout. fputs is used because the text is not a format string. */
void print_help(void)
{
    fputs(gnet_run_help_text, stdout);
}