diff --git a/.gitignore b/.gitignore index 12d2c999c2a14558a04c700f83ac359c17001c36..9057c39dd0f4ca6011f279e52d22028ccc68052a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,12 +2,18 @@ *.dSYM *.csv *.out +*.png +*.sh mnist/ +data/ +caffe/ +grasp/ images/ opencv/ convnet/ decaf/ submission/ +cfg/ darknet # OS Generated # diff --git a/Makefile b/Makefile index 6e7ecf752ce7a5ba4d1061996c6f7d01ed0cdb98..12432b9280de1fd6bd5c3eec80cda9b4eb9f4cce 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ OBJDIR=./obj/ CC=gcc NVCC=nvcc OPTS=-O3 -LDFLAGS=`pkg-config --libs opencv` -lm -pthread +LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++ COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/ CFLAGS=-Wall -Wfatal-errors @@ -25,7 +25,7 @@ CFLAGS+=-DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o endif diff --git a/src/cost_layer.c b/src/cost_layer.c index 34c8fb59c5cccbb0ccbdb91ea4a15b1fa253019f..815827510be29d9af8ba45530f147cd87fcf2899 100644 --- a/src/cost_layer.c +++ b/src/cost_layer.c @@ -10,7 +10,6 @@ COST_TYPE get_cost_type(char *s) { if (strcmp(s, "sse")==0) return SSE; - if (strcmp(s, "detection")==0) return DETECTION; fprintf(stderr, "Couldn't find activation function %s, going with SSE\n", s); return SSE; } @@ -20,8 +19,6 @@ char *get_cost_string(COST_TYPE a) switch(a){ case SSE: return "sse"; - case DETECTION: - return "detection"; } return "sse"; } @@ -41,17 +38,20 @@ cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type) return layer; } +void pull_cost_layer(cost_layer layer) +{ + cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs); +} +void push_cost_layer(cost_layer layer) +{ + cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs); +} + void forward_cost_layer(cost_layer layer, float *input, float *truth) { if (!truth) return; copy_cpu(layer.batch*layer.inputs, truth, 1, layer.delta, 1); axpy_cpu(layer.batch*layer.inputs, -1, input, 1, layer.delta, 1); - if(layer.type == DETECTION){ - int i; - for(i = 0; i < layer.batch*layer.inputs; ++i){ - if((i%25) && !truth[(i/25)*25]) layer.delta[i] = 0; - } - } *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); //printf("cost: %f\n", *layer.output); } @@ -66,14 +66,21 @@ void backward_cost_layer(const cost_layer layer, float *input, float *delta) void forward_cost_layer_gpu(cost_layer layer, float * input, float * truth) { if (!truth) return; + + /* + float *in = calloc(layer.inputs*layer.batch, sizeof(float)); + float *t = calloc(layer.inputs*layer.batch, sizeof(float)); + cuda_pull_array(input, in, layer.batch*layer.inputs); + cuda_pull_array(truth, t, layer.batch*layer.inputs); + forward_cost_layer(layer, in, t); + cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs); + free(in); + free(t); + */ copy_ongpu(layer.batch*layer.inputs, truth, 1, layer.delta_gpu, 1); axpy_ongpu(layer.batch*layer.inputs, -1, input, 1, layer.delta_gpu, 1); - if(layer.type==DETECTION){ - mask_ongpu(layer.inputs*layer.batch, layer.delta_gpu, truth, 25); - } - cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs); *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); //printf("cost: %f\n", *layer.output); diff --git a/src/cost_layer.h b/src/cost_layer.h index e58aae1fa0545c0e3e93377f6211d4aa24773910..0855405276b6aa64956c547d46481b06e501a598 100644 --- a/src/cost_layer.h +++ b/src/cost_layer.h @@ -2,12 +2,14 @@ #define COST_LAYER_H typedef enum{ - SSE, DETECTION + SSE } COST_TYPE; typedef struct { int inputs; int batch; + int coords; + int classes; float *delta; float *output; COST_TYPE type; diff --git a/src/cuda.c b/src/cuda.c index 8849fb1fcdba3f6cd10792e633167d2fcbe532c0..c9142905af50ea4e94c748dae96b5affebb76f74 100644 --- a/src/cuda.c +++ b/src/cuda.c @@ -5,6 +5,7 @@ int gpu_index = 0; #include "cuda.h" #include "utils.h" #include "blas.h" +#include "assert.h" #include <stdlib.h> @@ -15,6 +16,7 @@ void check_error(cudaError_t status) const char *s = cudaGetErrorString(status); char buffer[256]; printf("CUDA Error: %s\n", s); + assert(0); snprintf(buffer, 256, "CUDA Error: %s", s); error(buffer); } diff --git a/src/darknet.c b/src/darknet.c index fc58f3d17134b8b695005ee98ec1863be0377022..413d7f2186f03c89df198255be2b422d85e02af7 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -36,42 +36,30 @@ char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", void draw_detection(image im, float *box, int side) { int classes = 20; - int elems = 4+classes+1; + int elems = 4+classes; int j; int r, c; - float amount[AMNT] = {0}; - for(r = 0; r < side*side; ++r){ - float val = box[r*elems]; - for(j = 0; j < AMNT; ++j){ - if(val > amount[j]) { - float swap = val; - val = amount[j]; - amount[j] = swap; - } - } - } - float smallest = amount[AMNT-1]; for(r = 0; r < side; ++r){ for(c = 0; c < side; ++c){ j = (r*side + c) * elems; //printf("%d\n", j); //printf("Prob: %f\n", box[j]); - if(box[j] >= smallest){ - int class = max_index(box+j+1, classes); - int z; - for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+1+z], class_names[z]); - printf("%f %s\n", box[j+1+class], class_names[class]); + int class = max_index(box+j, classes); + if(box[j+class] > .02 || 1){ + //int z; + //for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]); + printf("%f %s\n", box[j+class], class_names[class]); float red = get_color(0,class,classes); float green = get_color(1,class,classes); float blue = get_color(2,class,classes); j += classes; int d = im.w/side; - int y = r*d+box[j+1]*d; - int x = c*d+box[j+2]*d; - int h = box[j+3]*im.h; - int w = box[j+4]*im.w; + int y = r*d+box[j]*d; + int x = c*d+box[j+1]*d; + int h = box[j+2]*im.h; + int w = box[j+3]*im.w; draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2,red,green,blue); } } @@ -117,21 +105,22 @@ void train_detection_net(char *cfgfile, char *weightfile) data train, buffer; int im_dim = 512; int jitter = 64; - pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer); + int classes = 21; + pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer); clock_t time; while(1){ i += 1; time=clock(); pthread_join(load_thread, 0); train = buffer; - load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer); + load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer); -/* - image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]); - draw_detection(im, train.y.vals[0], 7); - show_image(im, "truth"); - cvWaitKey(0); - */ + /* + image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]); + draw_detection(im, train.y.vals[0], 7); + show_image(im, "truth"); + cvWaitKey(0); + */ printf("Loaded: %lf seconds\n", sec(clock()-time)); time=clock(); @@ -139,7 +128,7 @@ void train_detection_net(char *cfgfile, char *weightfile) net.seen += imgs; avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); - if(i%800==0){ + if(i%100==0){ char buff[256]; sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); save_weights(net, buff); @@ -161,7 +150,7 @@ void validate_detection_net(char *cfgfile, char *weightfile) char **paths = (char **)list_to_array(plist); int num_output = 1225; int im_size = 448; - int classes = 20; + int classes = 21; int m = plist->size; int i = 0; @@ -180,30 +169,29 @@ void validate_detection_net(char *cfgfile, char *weightfile) num = (i+1)*m/splits - i*m/splits; char **part = paths+(i*m/splits); if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, im_size, im_size, &buffer); - + fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time)); matrix pred = network_predict_data(net, val); int j, k, class; for(j = 0; j < pred.rows; ++j){ - for(k = 0; k < pred.cols; k += classes+4+1){ + for(k = 0; k < pred.cols; k += classes+4){ /* - int z; - for(z = 0; z < 25; ++z) printf("%f, ", pred.vals[j][k+z]); - printf("\n"); - */ + int z; + for(z = 0; z < 25; ++z) printf("%f, ", pred.vals[j][k+z]); + printf("\n"); + */ - float p = pred.vals[j][k]; //if (pred.vals[j][k] > .001){ - for(class = 0; class < classes; ++class){ - int index = (k)/(classes+4+1); + for(class = 0; class < classes-1; ++class){ + int index = (k)/(classes+4); int r = index/7; int c = index%7; - float y = (r + pred.vals[j][k+1+classes])/7.; - float x = (c + pred.vals[j][k+2+classes])/7.; - float h = pred.vals[j][k+3+classes]; - float w = pred.vals[j][k+4+classes]; - printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, p*pred.vals[j][k+class+1], y, x, h, w); + float y = (r + pred.vals[j][k+0+classes])/7.; + float x = (c + pred.vals[j][k+1+classes])/7.; + float h = pred.vals[j][k+2+classes]; + float w = pred.vals[j][k+3+classes]; + printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class], y, x, h, w); } //} } @@ -462,7 +450,7 @@ void test_detection(char *cfgfile, char *weightfile) if(weightfile){ load_weights(&net, weightfile); } - int im_size = 224; + int im_size = 448; set_batch_network(&net, 1); srand(2222222); clock_t time; diff --git a/src/data.c b/src/data.c index a6b6db36a92785471b310eeef02fb4e83ebff315..0c935976863b520a3a56622de12611ea7d3f02ae 100644 --- a/src/data.c +++ b/src/data.c @@ -89,8 +89,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int float dw = (x - i*box_width)/box_width; float dh = (y - j*box_height)/box_height; //printf("%d %d %d %f %f\n", id, i, j, dh, dw); - int index = (i+j*num_width)*(4+classes+1); - truth[index++] = 1; + int index = (i+j*num_width)*(4+classes); truth[index+id] = 1; index += classes; truth[index++] = dh; @@ -98,6 +97,12 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int truth[index++] = h*(height+jitter)/height; truth[index++] = w*(width+jitter)/width; } + int i, j; + for(i = 0; i < num_height*num_width*(4+classes); i += 4+classes){ + int background = 1; + for(j = i; j < i+classes; ++j) if (truth[j]) background = 0; + truth[i+classes-1] = background; + } fclose(file); } @@ -209,7 +214,7 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes, data d; d.shallow = 0; d.X = load_image_paths(random_paths, n, h, w); - int k = nh*nw*(4+classes+1); + int k = nh*nw*(4+classes); d.y = make_matrix(n, k); for(i = 0; i < n; ++i){ int dx = rand()%jitter; diff --git a/src/detection_layer.c b/src/detection_layer.c index 65370795add253143e2f913fcfcde6472ac126d7..bbc2e4ff95172f21c570c2e1e8674da937c7073b 100644 --- a/src/detection_layer.c +++ b/src/detection_layer.c @@ -1,72 +1,123 @@ -int detection_out_height(detection_layer layer) +#include "detection_layer.h" +#include "activations.h" +#include "softmax_layer.h" +#include "blas.h" +#include "cuda.h" +#include <stdio.h> +#include <stdlib.h> + +int get_detection_layer_locations(detection_layer layer) { - return layer.size + layer.h*layer.stride; + return layer.inputs / (layer.classes+layer.coords+layer.rescore); } -int detection_out_width(detection_layer layer) +int get_detection_layer_output_size(detection_layer layer) { - return layer.size + layer.w*layer.stride; + return get_detection_layer_locations(layer)*(layer.classes+layer.coords); } -detection_layer *make_detection_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation) +detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore) { - int i; - size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter... detection_layer *layer = calloc(1, sizeof(detection_layer)); - layer->h = h; - layer->w = w; - layer->c = c; - layer->n = n; - layer->batch = batch; - layer->stride = stride; - layer->size = size; - assert(c%n == 0); - - layer->filters = calloc(c*size*size, sizeof(float)); - layer->filter_updates = calloc(c*size*size, sizeof(float)); - layer->filter_momentum = calloc(c*size*size, sizeof(float)); - - float scale = 1./(size*size*c); - for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform()); - - int out_h = detection_out_height(*layer); - int out_w = detection_out_width(*layer); - - layer->output = calloc(layer->batch * out_h * out_w * n, sizeof(float)); - layer->delta = calloc(layer->batch * out_h * out_w * n, sizeof(float)); - layer->activation = activation; + layer->batch = batch; + layer->inputs = inputs; + layer->classes = classes; + layer->coords = coords; + layer->rescore = rescore; + int outputs = get_detection_layer_output_size(*layer); + layer->output = calloc(batch*outputs, sizeof(float)); + layer->delta = calloc(batch*outputs, sizeof(float)); + #ifdef GPU + layer->output_gpu = cuda_make_array(0, batch*outputs); + layer->delta_gpu = cuda_make_array(0, batch*outputs); + #endif - fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); + fprintf(stderr, "Detection Layer\n"); srand(0); return layer; } -void forward_detection_layer(const detection_layer layer, float *in) +void forward_detection_layer(const detection_layer layer, float *in, float *truth) { - int out_h = detection_out_height(layer); - int out_w = detection_out_width(layer); - int i,j,fh, fw,c; - memset(layer.output, 0, layer->batch*layer->n*out_h*out_w*sizeof(float)); - for(c = 0; c < layer.c; ++c){ - for(i = 0; i < layer.h; ++i){ - for(j = 0; j < layer.w; ++j){ - float val = layer->input[j+(i + c*layer.h)*layer.w]; - for(fh = 0; fh < layer.size; ++fh){ - for(fw = 0; fw < layer.size; ++fw){ - int h = i*layer.stride + fh; - int w = j*layer.stride + fw; - layer.output[w+(h+c/n*out_h)*out_w] += val*layer->filters[fw+(fh+c*layer.size)*layer.size]; - } - } - } + int in_i = 0; + int out_i = 0; + int locations = get_detection_layer_locations(layer); + int i,j; + for(i = 0; i < layer.batch*locations; ++i){ + int mask = (!truth || !truth[out_i + layer.classes - 1]); + float scale = 1; + if(layer.rescore) scale = in[in_i++]; + for(j = 0; j < layer.classes; ++j){ + layer.output[out_i++] = scale*in[in_i++]; + } + softmax_array(layer.output + out_i - layer.classes, layer.classes, layer.output + out_i - layer.classes); + activate_array(layer.output+out_i, layer.coords, SIGMOID); + for(j = 0; j < layer.coords; ++j){ + layer.output[out_i++] = mask*in[in_i++]; } + //printf("%d\n", mask); + //for(j = 0; j < layer.classes+layer.coords; ++j) printf("%f ", layer.output[i*(layer.classes+layer.coords)+j]); + //printf ("\n"); } } -void backward_detection_layer(const detection_layer layer, float *delta) +void backward_detection_layer(const detection_layer layer, float *in, float *delta) { + int locations = get_detection_layer_locations(layer); + int i,j; + int in_i = 0; + int out_i = 0; + for(i = 0; i < layer.batch*locations; ++i){ + float scale = 1; + float latent_delta = 0; + if(layer.rescore) scale = in[in_i++]; + for(j = 0; j < layer.classes; ++j){ + latent_delta += in[in_i]*layer.delta[out_i]; + delta[in_i++] = scale*layer.delta[out_i++]; + } + + for(j = 0; j < layer.coords; ++j){ + delta[in_i++] = layer.delta[out_i++]; + } + gradient_array(in + in_i - layer.coords, layer.coords, SIGMOID, layer.delta + out_i - layer.coords); + if(layer.rescore) delta[in_i-layer.coords-layer.classes-layer.rescore] = latent_delta; + } } +#ifdef GPU + +void forward_detection_layer_gpu(const detection_layer layer, float *in, float *truth) +{ + int outputs = get_detection_layer_output_size(layer); + float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); + float *truth_cpu = 0; + if(truth){ + truth_cpu = calloc(layer.batch*outputs, sizeof(float)); + cuda_pull_array(truth, truth_cpu, layer.batch*outputs); + } + cuda_pull_array(in, in_cpu, layer.batch*layer.inputs); + forward_detection_layer(layer, in_cpu, truth_cpu); + cuda_push_array(layer.output_gpu, layer.output, layer.batch*outputs); + free(in_cpu); + if(truth_cpu) free(truth_cpu); +} + +void backward_detection_layer_gpu(detection_layer layer, float *in, float *delta) +{ + int outputs = get_detection_layer_output_size(layer); + + float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); + float *delta_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); + + cuda_pull_array(in, in_cpu, layer.batch*layer.inputs); + cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*outputs); + backward_detection_layer(layer, in_cpu, delta_cpu); + cuda_push_array(delta, delta_cpu, layer.batch*layer.inputs); + + free(in_cpu); + free(delta_cpu); +} +#endif diff --git a/src/detection_layer.h b/src/detection_layer.h index fad0281e97ab0e10ed8f402828aea95de3732c3a..e7e9e2064d3db2d9dc2bc69e36302f946c41a848 100644 --- a/src/detection_layer.h +++ b/src/detection_layer.h @@ -3,38 +3,26 @@ typedef struct { int batch; - int h,w,c; - int n; - int size; - int stride; - - float *filters; - float *filter_updates; - float *filter_momentum; - - float *biases; - float *bias_updates; - float *bias_momentum; - - float *col_image; - float *delta; + int inputs; + int classes; + int coords; + int rescore; float *output; - + float *delta; #ifdef GPU - cl_mem filters_cl; - cl_mem filter_updates_cl; - cl_mem filter_momentum_cl; - - cl_mem biases_cl; - cl_mem bias_updates_cl; - cl_mem bias_momentum_cl; - - cl_mem col_image_cl; - cl_mem delta_cl; - cl_mem output_cl; + float * output_gpu; + float * delta_gpu; #endif +} detection_layer; - ACTIVATION activation; -} convolutional_layer; +detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore); +void forward_detection_layer(const detection_layer layer, float *in, float *truth); +void backward_detection_layer(const detection_layer layer, float *in, float *delta); +int get_detection_layer_output_size(detection_layer layer); + +#ifdef GPU +void forward_detection_layer_gpu(const detection_layer layer, float *in, float *truth); +void backward_detection_layer_gpu(detection_layer layer, float *in, float *delta); +#endif #endif diff --git a/src/image.c b/src/image.c index 53cf281c1f96ae241b092329fcdcfe0100f745de..ee7a823fdbdec2957c59811cdef51be2e7591246 100644 --- a/src/image.c +++ b/src/image.c @@ -13,7 +13,7 @@ float get_color(int c, int x, int max) int j = ceil(ratio); ratio -= i; float r = (1-ratio) * colors[i][c] + ratio*colors[j][c]; - printf("%f\n", r); + //printf("%f\n", r); return r; } diff --git a/src/network.c b/src/network.c index bf0d63f976d8fa11c2f4f72e91a91af79bd5f766..b60f05953572bc1c99e3be9c2a143ac5c3a6aa11 100644 --- a/src/network.c +++ b/src/network.c @@ -9,6 +9,7 @@ #include "connected_layer.h" #include "convolutional_layer.h" #include "deconvolutional_layer.h" +#include "detection_layer.h" #include "maxpool_layer.h" #include "cost_layer.h" #include "normalization_layer.h" @@ -29,6 +30,8 @@ char *get_layer_string(LAYER_TYPE a) return "maxpool"; case SOFTMAX: return "softmax"; + case DETECTION: + return "detection"; case NORMALIZATION: return "normalization"; case DROPOUT: @@ -76,6 +79,11 @@ void forward_network(network net, float *input, float *truth, int train) forward_deconvolutional_layer(layer, input); input = layer.output; } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + forward_detection_layer(layer, input, truth); + input = layer.output; + } else if(net.types[i] == CONNECTED){ connected_layer layer = *(connected_layer *)net.layers[i]; forward_connected_layer(layer, input); @@ -152,6 +160,9 @@ float *get_network_output_layer(network net, int i) } else if(net.types[i] == MAXPOOL){ maxpool_layer layer = *(maxpool_layer *)net.layers[i]; return layer.output; + } else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + return layer.output; } else if(net.types[i] == SOFTMAX){ softmax_layer layer = *(softmax_layer *)net.layers[i]; return layer.output; @@ -193,6 +204,9 @@ float *get_network_delta_layer(network net, int i) } else if(net.types[i] == SOFTMAX){ softmax_layer layer = *(softmax_layer *)net.layers[i]; return layer.delta; + } else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + return layer.delta; } else if(net.types[i] == DROPOUT){ if(i == 0) return 0; return get_network_delta_layer(net, i-1); @@ -243,7 +257,7 @@ int get_predicted_class_network(network net) return max_index(out, k); } -void backward_network(network net, float *input) +void backward_network(network net, float *input, float *truth) { int i; float *prev_input; @@ -272,6 +286,10 @@ void backward_network(network net, float *input) dropout_layer layer = *(dropout_layer *)net.layers[i]; backward_dropout_layer(layer, prev_delta); } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + backward_detection_layer(layer, prev_input, prev_delta); + } else if(net.types[i] == NORMALIZATION){ normalization_layer layer = *(normalization_layer *)net.layers[i]; if(i != 0) backward_normalization_layer(layer, prev_input, prev_delta); @@ -297,7 +315,7 @@ float train_network_datum(network net, float *x, float *y) if(gpu_index >= 0) return train_network_datum_gpu(net, x, y); #endif forward_network(net, x, y, 1); - backward_network(net, x); + backward_network(net, x, y); float error = get_network_cost(net); update_network(net); return error; @@ -351,7 +369,7 @@ float train_network_batch(network net, data d, int n) float *x = d.X.vals[index]; float *y = d.y.vals[index]; forward_network(net, x, y, 1); - backward_network(net, x); + backward_network(net, x, y); sum += get_network_cost(net); } update_network(net); @@ -381,7 +399,6 @@ void set_learning_network(network *net, float rate, float momentum, float decay) } } - void set_batch_network(network *net, int b) { net->batch = b; @@ -404,6 +421,9 @@ void set_batch_network(network *net, int b) } else if(net->types[i] == DROPOUT){ dropout_layer *layer = (dropout_layer *) net->layers[i]; layer->batch = b; + } else if(net->types[i] == DETECTION){ + detection_layer *layer = (detection_layer *) net->layers[i]; + layer->batch = b; } else if(net->types[i] == FREEWEIGHT){ freeweight_layer *layer = (freeweight_layer *) net->layers[i]; @@ -445,6 +465,9 @@ int get_network_input_size_layer(network net, int i) } else if(net.types[i] == DROPOUT){ dropout_layer layer = *(dropout_layer *) net.layers[i]; return layer.inputs; + } else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *) net.layers[i]; + return layer.inputs; } else if(net.types[i] == CROP){ crop_layer layer = *(crop_layer *) net.layers[i]; return layer.c*layer.h*layer.w; @@ -473,6 +496,10 @@ int get_network_output_size_layer(network net, int i) image output = get_deconvolutional_image(layer); return output.h*output.w*output.c; } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + return get_detection_layer_output_size(layer); + } else if(net.types[i] == MAXPOOL){ maxpool_layer layer = *(maxpool_layer *)net.layers[i]; image output = get_maxpool_image(layer); diff --git a/src/network.h b/src/network.h index 66873d2c0ffe3f0e1dc2984255835e6c4f037c0e..d2fb346cee9c5a31be0f44b2bb24829213eac35a 100644 --- a/src/network.h +++ b/src/network.h @@ -11,6 +11,7 @@ typedef enum { CONNECTED, MAXPOOL, SOFTMAX, + DETECTION, NORMALIZATION, DROPOUT, FREEWEIGHT, @@ -48,7 +49,7 @@ char *get_layer_string(LAYER_TYPE a); network make_network(int n, int batch); void forward_network(network net, float *input, float *truth, int train); -void backward_network(network net, float *input); +void backward_network(network net, float *input, float *truth); void update_network(network net); float train_network(network net, data d); diff --git a/src/network_kernels.cu b/src/network_kernels.cu index b83d0566fa0c44cdb6869d2da27dfb254d27decc..928c7f95677fd0463607854e48aa147e8be601cd 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -9,6 +9,7 @@ extern "C" { #include "crop_layer.h" #include "connected_layer.h" +#include "detection_layer.h" #include "convolutional_layer.h" #include "deconvolutional_layer.h" #include "maxpool_layer.h" @@ -47,6 +48,11 @@ void forward_network_gpu(network net, float * input, float * truth, int train) forward_connected_layer_gpu(layer, input); input = layer.output_gpu; } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + forward_detection_layer_gpu(layer, input, truth); + input = layer.output_gpu; + } else if(net.types[i] == MAXPOOL){ maxpool_layer layer = *(maxpool_layer *)net.layers[i]; forward_maxpool_layer_gpu(layer, input); @@ -73,7 +79,7 @@ void forward_network_gpu(network net, float * input, float * truth, int train) } } -void backward_network_gpu(network net, float * input) +void backward_network_gpu(network net, float * input, float *truth) { int i; float * prev_input; @@ -103,6 +109,10 @@ void backward_network_gpu(network net, float * input) connected_layer layer = *(connected_layer *)net.layers[i]; backward_connected_layer_gpu(layer, prev_input, prev_delta); } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + backward_detection_layer_gpu(layer, prev_input, prev_delta); + } else if(net.types[i] == MAXPOOL){ maxpool_layer layer = *(maxpool_layer *)net.layers[i]; backward_maxpool_layer_gpu(layer, prev_delta); @@ -148,6 +158,10 @@ float * get_network_output_gpu_layer(network net, int i) deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; return layer.output_gpu; } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + return layer.output_gpu; + } else if(net.types[i] == CONNECTED){ connected_layer layer = *(connected_layer *)net.layers[i]; return layer.output_gpu; @@ -176,6 +190,10 @@ float * get_network_delta_gpu_layer(network net, int i) convolutional_layer layer = *(convolutional_layer *)net.layers[i]; return layer.delta_gpu; } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + return layer.delta_gpu; + } else if(net.types[i] == DECONVOLUTIONAL){ deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; return layer.delta_gpu; @@ -215,7 +233,7 @@ float train_network_datum_gpu(network net, float *x, float *y) forward_network_gpu(net, *net.input_gpu, *net.truth_gpu, 1); //printf("forw %f\n", sec(clock() - time)); //time = clock(); - backward_network_gpu(net, *net.input_gpu); + backward_network_gpu(net, *net.input_gpu, *net.truth_gpu); //printf("back %f\n", sec(clock() - time)); //time = clock(); update_network_gpu(net); @@ -244,6 +262,12 @@ float *get_network_output_layer_gpu(network net, int i) cuda_pull_array(layer.output_gpu, layer.output, layer.outputs*layer.batch); return layer.output; } + else if(net.types[i] == DETECTION){ + detection_layer layer = *(detection_layer *)net.layers[i]; + int outputs = get_detection_layer_output_size(layer); + cuda_pull_array(layer.output_gpu, layer.output, outputs*layer.batch); + return layer.output; + } else if(net.types[i] == MAXPOOL){ maxpool_layer layer = *(maxpool_layer *)net.layers[i]; return layer.output; diff --git a/src/option_list.c b/src/option_list.c index 76e10166369d5ab8c5c273c79a76bd7d8423efef..f5536e11cff2fef59dc66c6b0dc873c4b3ab4df8 100644 --- a/src/option_list.c +++ b/src/option_list.c @@ -53,6 +53,13 @@ int option_find_int(list *l, char *key, int def) return def; } +int option_find_int_quiet(list *l, char *key, int def) +{ + char *v = option_find(l, key); + if(v) return atoi(v); + return def; +} + float option_find_float_quiet(list *l, char *key, float def) { char *v = option_find(l, key); diff --git a/src/option_list.h b/src/option_list.h index fa795f3e2ea400af091b70c134716dfa24f22aab..4441462821550742834659da8d3c40c6859d4948 100644 --- a/src/option_list.h +++ b/src/option_list.h @@ -13,6 +13,7 @@ void option_insert(list *l, char *key, char *val); char *option_find(list *l, char *key); char *option_find_str(list *l, char *key, char *def); int option_find_int(list *l, char *key, int def); +int option_find_int_quiet(list *l, char *key, int def); float option_find_float(list *l, char *key, float def); float option_find_float_quiet(list *l, char *key, float def); void option_unused(list *l); diff --git a/src/parser.c b/src/parser.c index 850cc38832566d72f54ba9448a260be85b0133ff..53e1f569f99a0c3af13626860388174221e7b0bf 100644 --- a/src/parser.c +++ b/src/parser.c @@ -13,6 +13,7 @@ #include "normalization_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" +#include "detection_layer.h" #include "freeweight_layer.h" #include "list.h" #include "option_list.h" @@ -32,6 +33,7 @@ int is_freeweight(section *s); int is_softmax(section *s); int is_crop(section *s); int is_cost(section *s); +int is_detection(section *s); int is_normalization(section *s); list *read_cfg(char *filename); @@ -204,6 +206,24 @@ softmax_layer *parse_softmax(list *options, network *net, int count) return layer; } +detection_layer *parse_detection(list *options, network *net, int count) +{ + int input; + if(count == 0){ + input = option_find_int(options, "input",1); + net->batch = option_find_int(options, "batch",1); + net->seen = option_find_int(options, "seen",0); + }else{ + input = get_network_output_size_layer(*net, count-1); + } + int coords = option_find_int(options, "coords", 1); + int classes = option_find_int(options, "classes", 1); + int rescore = option_find_int(options, "rescore", 1); + detection_layer *layer = make_detection_layer(net->batch, input, classes, coords, rescore); + option_unused(options); + return layer; +} + cost_layer *parse_cost(list *options, network *net, int count) { int input; @@ -368,6 +388,10 @@ network parse_network_cfg(char *filename) cost_layer *layer = parse_cost(options, &net, count); net.types[count] = COST; net.layers[count] = layer; + }else if(is_detection(s)){ + detection_layer *layer = parse_detection(options, &net, count); + net.types[count] = DETECTION; + net.layers[count] = layer; }else if(is_softmax(s)){ softmax_layer *layer = parse_softmax(options, &net, count); net.types[count] = SOFTMAX; @@ -410,6 +434,10 @@ int is_cost(section *s) { return (strcmp(s->type, "[cost]")==0); } +int is_detection(section *s) +{ + return (strcmp(s->type, "[detection]")==0); +} int is_deconvolutional(section *s) { return (strcmp(s->type, "[deconv]")==0 @@ -684,6 +712,13 @@ void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count) fprintf(fp, "\n"); } +void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count) +{ + fprintf(fp, "[detection]\n"); + fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\n", l->classes, l->coords, l->rescore); + fprintf(fp, "\n"); +} + void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count) { fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type)); @@ -815,6 +850,8 @@ void save_network(network net, char *filename) print_normalization_cfg(fp, (normalization_layer *)net.layers[i], net, i); else if(net.types[i] == SOFTMAX) print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i); + else if(net.types[i] == DETECTION) + print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i); else if(net.types[i] == COST) print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i); } diff --git a/src/softmax_layer.h b/src/softmax_layer.h index 1c1cdae839a652734e9a88f5d83f9600abe6dab4..3632c7471938b3e8012641e8ead512e396a67370 100644 --- a/src/softmax_layer.h +++ b/src/softmax_layer.h @@ -13,6 +13,7 @@ typedef struct { #endif } softmax_layer; +void softmax_array(float *input, int n, float *output); softmax_layer *make_softmax_layer(int batch, int groups, int inputs); void forward_softmax_layer(const softmax_layer layer, float *input); void backward_softmax_layer(const softmax_layer layer, float *delta);