diff --git a/src/art.c b/src/art.c index 785ab526bf2387e977bbbce0c088bce646296d51..9a0559e5ab0ec9cdaf6bccb3fc3931d10346fc1d 100644 --- a/src/art.c +++ b/src/art.c @@ -53,7 +53,7 @@ void demo_art(char *cfgfile, char *weightfile, int cam_index) printf("["); int upper = 30; for(i = 0; i < upper; ++i){ - printf("%s", ((i+.5) < score*upper) ? "\u2588" : " "); + printf("%c", ((i+.5) < score*upper) ? 219 : ' '); } printf("]\n"); diff --git a/src/classifier.c b/src/classifier.c index 24b28b5b92cadbcfd028a4dbbb314a53d54f9d0a..2d0d0e0ce04210de17a66f83ddc643ce151bb8b3 100644 --- a/src/classifier.c +++ b/src/classifier.c @@ -51,7 +51,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) } if(clear) *net.seen = 0; printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - int imgs = net.batch; + int imgs = net.batch*net.subdivisions; list *options = read_data_cfg(datacfg); @@ -338,10 +338,10 @@ void validate_classifier_single(char *datacfg, char *filename, char *weightfile) { int i, j; network net = parse_network_cfg(filename); - set_batch_network(&net, 1); if(weightfile){ load_weights(&net, weightfile); } + set_batch_network(&net, 1); srand(time(0)); list *options = read_data_cfg(datacfg); diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 1de9dc0110e66a9d72b05b48506d6e502a2feaff..1590fe7bc80381b4f36b90da17b4cea0c4b92636 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -72,10 +72,6 @@ void binarize_filters_gpu(float *filters, int n, int size, float *binary) void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) { int i; - int m = l.n; - int k = l.size*l.size*l.c; - int n = convolutional_out_height(l)* - convolutional_out_width(l); fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1); if(l.binary){ @@ -109,6 +105,9 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) l.output_gpu); #else + int m = l.n; + int k = l.size*l.size*l.c; + int n = l.out_w*l.out_h; for(i = 0; i < l.batch; ++i){ im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace); float * a = l.filters_gpu; @@ -121,23 +120,18 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) if (l.batch_normalize) { forward_batchnorm_layer_gpu(l, state); } - add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n); + add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h); - activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation); + activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation); //if(l.dot > 0) dot_error_gpu(l); if(l.binary || l.xnor) swap_binary(&l); } void backward_convolutional_layer_gpu(convolutional_layer l, network_state state) { - int m = l.n; - int n = l.size*l.size*l.c; - int k = convolutional_out_height(l)* - convolutional_out_width(l); - - gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu); + gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu); - backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k); + backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h); if(l.batch_normalize){ backward_batchnorm_layer_gpu(l, state); @@ -181,6 +175,10 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state } #else + int m = l.n; + int n = l.size*l.size*l.c; + int k = l.out_w*l.out_h; + int i; for(i = 0; i < l.batch; ++i){ float * a = l.delta_gpu; diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index af867e5ef595b01f933f6e914b3043991626df9e..c88cb0ad600a88cfa3312722bfc324e2a6901638 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -14,6 +14,7 @@ #ifndef AI2 #define AI2 0 +void forward_xnor_layer(layer l, network_state state); #endif void swap_binary(convolutional_layer *l) @@ -127,6 +128,47 @@ size_t get_workspace_size(layer l){ #endif } +#ifdef GPU +#ifdef CUDNN +void cudnn_convolutional_setup(layer *l) +{ + cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); + cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); + cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); + + cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); + cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); + cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); + int padding = l->pad ? l->size/2 : 0; + cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); + cudnnGetConvolutionForwardAlgorithm(cudnn_handle(), + l->srcTensorDesc, + l->filterDesc, + l->convDesc, + l->dstTensorDesc, + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + 0, + &l->fw_algo); + cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(), + l->filterDesc, + l->ddstTensorDesc, + l->convDesc, + l->dsrcTensorDesc, + CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, + 0, + &l->bd_algo); + cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(), + l->srcTensorDesc, + l->ddstTensorDesc, + l->convDesc, + l->dfilterDesc, + CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, + 0, + &l->bf_algo); +} +#endif +#endif + convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor) { int i; @@ -231,39 +273,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int cudnnCreateTensorDescriptor(&l.ddstTensorDesc); cudnnCreateFilterDescriptor(&l.dfilterDesc); cudnnCreateConvolutionDescriptor(&l.convDesc); - cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); - cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); - cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); - - cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); - cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); - cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); - int padding = l.pad ? l.size/2 : 0; - cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION); - cudnnGetConvolutionForwardAlgorithm(cudnn_handle(), - l.srcTensorDesc, - l.filterDesc, - l.convDesc, - l.dstTensorDesc, - CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, - 0, - &l.fw_algo); - cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(), - l.filterDesc, - l.ddstTensorDesc, - l.convDesc, - l.dsrcTensorDesc, - CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, - 0, - &l.bd_algo); - cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(), - l.srcTensorDesc, - l.ddstTensorDesc, - l.convDesc, - l.dfilterDesc, - CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, - 0, - &l.bf_algo); + cudnn_convolutional_setup(&l); #endif #endif l.workspace_size = get_workspace_size(l); @@ -335,39 +345,7 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h) l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n); l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n); #ifdef CUDNN - cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); - - cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); - int padding = l->pad ? l->size/2 : 0; - cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); - cudnnGetConvolutionForwardAlgorithm(cudnn_handle(), - l->srcTensorDesc, - l->filterDesc, - l->convDesc, - l->dstTensorDesc, - CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, - 0, - &l->fw_algo); - cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(), - l->filterDesc, - l->ddstTensorDesc, - l->convDesc, - l->dsrcTensorDesc, - CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, - 0, - &l->bd_algo); - cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(), - l->srcTensorDesc, - l->ddstTensorDesc, - l->convDesc, - l->dfilterDesc, - CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, - 0, - &l->bf_algo); + cudnn_convolutional_setup(l); #endif #endif l->workspace_size = get_workspace_size(*l); diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 3d52b229ab060aebfb087d1c7a4dafeb3c95e3d3..972b765608f5392ac1c991421d570ad1587637f1 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -19,6 +19,9 @@ void pull_convolutional_layer(convolutional_layer layer); void add_bias_gpu(float *output, float *biases, int batch, int n, int size); void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size); +#ifdef CUDNN +void cudnn_convolutional_setup(layer *l); +#endif #endif convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalization, int binary, int xnor); diff --git a/src/detection_layer.c b/src/detection_layer.c index e103b4ea954e43c3db98a0080b2f5259c16c8f80..f7019ef2e9238fff9b08d18499ed86a90e8fccac 100644 --- a/src/detection_layer.c +++ b/src/detection_layer.c @@ -133,6 +133,9 @@ void forward_detection_layer(const detection_layer l, network_state state) best_index = 0; } } + if(1 && *(state.net.seen) < 100000){ + best_index = rand()%l.n; + } int box_index = index + locations*(l.classes + l.n) + (i*l.n + best_index) * l.coords; int tbox_index = truth_index + 1 + l.classes; @@ -181,7 +184,6 @@ void forward_detection_layer(const detection_layer l, network_state state) for (b = 0; b < l.batch; ++b) { int index = b*l.inputs; for (i = 0; i < locations; ++i) { - int truth_index = (b*locations + i)*(1+l.coords+l.classes); for (j = 0; j < l.n; ++j) { int p_index = index + locations*l.classes + i*l.n + j; costs[b*locations*l.n + i*l.n + j] = l.delta[p_index]*l.delta[p_index]; @@ -194,7 +196,6 @@ void forward_detection_layer(const detection_layer l, network_state state) for (b = 0; b < l.batch; ++b) { int index = b*l.inputs; for (i = 0; i < locations; ++i) { - int truth_index = (b*locations + i)*(1+l.coords+l.classes); for (j = 0; j < l.n; ++j) { int p_index = index + locations*l.classes + i*l.n + j; if (l.delta[p_index]*l.delta[p_index] < cutoff) l.delta[p_index] = 0; @@ -233,7 +234,7 @@ void forward_detection_layer_gpu(const detection_layer l, network_state state) cuda_pull_array(state.truth, truth_cpu, num_truth); } cuda_pull_array(state.input, in_cpu, l.batch*l.inputs); - network_state cpu_state; + network_state cpu_state = state; cpu_state.train = state.train; cpu_state.truth = truth_cpu; cpu_state.input = in_cpu; diff --git a/src/go.c b/src/go.c index 7883ed58ec67b962cb4b25fa9be0e804bfe53911..91beaf1827bc9074772dd401ca4ffece31e84b0e 100644 --- a/src/go.c +++ b/src/go.c @@ -217,7 +217,7 @@ void print_board(float *board, int swap, int *indexes) } fprintf(stream, "\n"); for(j = 0; j < 19; ++j){ - fprintf(stream, "%2d ", (inverted) ? 19-j : j+1); + fprintf(stream, "%2d", (inverted) ? 19-j : j+1); for(i = 0; i < 19; ++i){ int index = j*19 + i; if(indexes){ @@ -225,17 +225,26 @@ void print_board(float *board, int swap, int *indexes) for(n = 0; n < nind; ++n){ if(index == indexes[n]){ found = 1; + /* if(n == 0) fprintf(stream, "\uff11"); else if(n == 1) fprintf(stream, "\uff12"); else if(n == 2) fprintf(stream, "\uff13"); else if(n == 3) fprintf(stream, "\uff14"); else if(n == 4) fprintf(stream, "\uff15"); + */ + if(n == 0) fprintf(stream, " 1"); + else if(n == 1) fprintf(stream, " 2"); + else if(n == 2) fprintf(stream, " 3"); + else if(n == 3) fprintf(stream, " 4"); + else if(n == 4) fprintf(stream, " 5"); } } if(found) continue; } - if(board[index]*-swap > 0) fprintf(stream, "\u25C9 "); - else if(board[index]*-swap < 0) fprintf(stream, "\u25EF "); + //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 "); + //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF "); + if(board[index]*-swap > 0) fprintf(stream, " O"); + else if(board[index]*-swap < 0) fprintf(stream, " X"); else fprintf(stream, " "); } fprintf(stream, "\n"); @@ -640,8 +649,10 @@ void test_go(char *cfg, char *weights, int multi) col = index % 19; printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100); } - if(color == 1) printf("\u25EF Enter move: "); - else printf("\u25C9 Enter move: "); + //if(color == 1) printf("\u25EF Enter move: "); + //else printf("\u25C9 Enter move: "); + if(color == 1) printf("X Enter move: "); + else printf("O Enter move: "); char c; char *line = fgetl(stdin); diff --git a/src/network.c b/src/network.c index 2960d67a399365c1c7aea7797fb640d73b9cbc6f..51f74d962323c8da80f13ca84eda28294447c140 100644 --- a/src/network.c +++ b/src/network.c @@ -392,6 +392,11 @@ void set_batch_network(network *net, int b) int i; for(i = 0; i < net->n; ++i){ net->layers[i].batch = b; + #ifdef CUDNN + if(net->layers[i].type == CONVOLUTIONAL){ + cudnn_convolutional_setup(net->layers + i); + } + #endif } } diff --git a/src/rnn.c b/src/rnn.c index cda38ef8ae49c0275d14570a790cad57772b0899..4f0e011a19c255350694d01357e1bc038e95e9e2 100644 --- a/src/rnn.c +++ b/src/rnn.c @@ -280,7 +280,7 @@ void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float t printf("\n"); } -void test_tactic_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file) +void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file) { char **tokens = 0; if(token_file){ @@ -301,9 +301,8 @@ void test_tactic_rnn(char *cfgfile, char *weightfile, int num, char *seed, float int i, j; for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp; int c = 0; - int len = strlen(seed); float *input = calloc(inputs, sizeof(float)); - float *out; + float *out = 0; while((c = getc(stdin)) != EOF){ input[c] = 1; @@ -490,5 +489,5 @@ void run_char_rnn(int argc, char **argv) else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed); else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed); else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens); - else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, seed, temp, rseed, tokens); + else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens); }