From 0cd2379e2ccbad07bad3f88f8dc564776605802d Mon Sep 17 00:00:00 2001 From: Joseph Redmon <pjreddie@gmail.com> Date: Sat, 14 Nov 2015 12:34:17 -0800 Subject: [PATCH] some changes --- Makefile | 2 +- src/coco.c | 8 +- src/coco_kernels.cu | 109 ++++++++++++++++ src/data.c | 4 +- src/layer.h | 3 +- src/local_kernels.cu | 226 +++++++++++++++++++++++++++++++++ src/local_layer.c | 275 +++++++++++++++++++++++++++++++++++++++++ src/local_layer.h | 31 +++++ src/network.c | 9 ++ src/network_kernels.cu | 7 ++ src/nightmare.c | 7 +- src/parser.c | 50 ++++++++ src/yolo.c | 26 ++-- src/yolo_kernels.cu | 115 +++++++++++++---- 14 files changed, 817 insertions(+), 55 deletions(-) create mode 100644 src/coco_kernels.cu create mode 100644 src/local_kernels.cu create mode 100644 src/local_layer.c create mode 100644 src/local_layer.h diff --git a/Makefile b/Makefile index 44a193f..1b6aa80 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ CFLAGS+= -DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o yolo_kernels.o endif diff --git a/src/coco.c b/src/coco.c index aadf09d..cef6ade 100644 --- a/src/coco.c +++ b/src/coco.c @@ -15,7 +15,7 @@ char *coco_classes[] = {"person","bicycle","car","motorcycle","airplane","bus"," int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90}; -void draw_coco(image im, int num, float thresh, box *boxes, float **probs, char *label) +void draw_coco(image im, int num, float thresh, box *boxes, float **probs) { int classes = 80; int i; @@ -38,7 +38,6 @@ void draw_coco(image im, int num, float thresh, box *boxes, float **probs, char draw_box_width(im, left, top, right, bot, width, red, green, blue); } } - show_image(im, label); } void train_coco(char *cfgfile, char *weightfile) @@ -215,7 +214,7 @@ void validate_coco(char *cfgfile, char *weightfile) int i=0; int t; - float thresh = .001; + float thresh = .01; int nms = 1; float iou_thresh = .5; @@ -393,7 +392,8 @@ void test_coco(char *cfgfile, char *weightfile, char *filename, float thresh) float *predictions = network_predict(net, X); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); convert_coco_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); - draw_coco(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions"); + draw_coco(im, l.side*l.side*l.n, thresh, boxes, probs); + show_image(im, "predictions"); show_image(sized, "resized"); free_image(im); diff --git a/src/coco_kernels.cu b/src/coco_kernels.cu new file mode 100644 index 0000000..9c201c0 --- /dev/null +++ b/src/coco_kernels.cu @@ -0,0 +1,109 @@ +extern "C" { +#include "network.h" +#include "detection_layer.h" +#include "cost_layer.h" +#include "utils.h" +#include "parser.h" +#include "box.h" +#include "image.h" +} + +#ifdef OPENCV +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgproc/imgproc.hpp" +extern "C" image ipl_to_image(IplImage* src); +extern "C" void convert_coco_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness); +extern "C" void draw_coco(image im, int num, float thresh, box *boxes, float **probs); + +static float **probs; +static box *boxes; +static network net; +static image in ; +static image in_s ; +static image det ; +static image det_s; +static image disp ; +static cv::VideoCapture cap(0); + +void *fetch_in_thread(void *ptr) +{ + cv::Mat frame_m; + cap >> frame_m; + IplImage frame = frame_m; + in = ipl_to_image(&frame); + rgbgr_image(in); + in_s = resize_image(in, net.w, net.h); + return 0; +} + +void *detect_in_thread(void *ptr) +{ + float nms = .4; + float thresh = .2; + + detection_layer l = net.layers[net.n-1]; + float *X = det_s.data; + float *predictions = network_predict(net, X); + free_image(det_s); + convert_coco_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); + if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms); + printf("\033[2J"); + printf("\033[1;1H"); + printf("\nObjects:\n\n"); + draw_coco(det, l.side*l.side*l.n, thresh, boxes, probs); + return 0; +} + +extern "C" void demo_coco(char *cfgfile, char *weightfile, float thresh) +{ + printf("YOLO demo\n"); + net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + + srand(2222222); + + if(!cap.isOpened()) error("Couldn't connect to webcam.\n"); + + detection_layer l = net.layers[net.n-1]; + int j; + + boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box)); + probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *)); + for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *)); + + pthread_t fetch_thread; + pthread_t detect_thread; + + fetch_in_thread(0); + det = in; + det_s = in_s; + + fetch_in_thread(0); + detect_in_thread(0); + disp = det; + det = in; + det_s = in_s; + + while(1){ + if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed"); + if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed"); + show_image(disp, "YOLO"); + free_image(disp); + cvWaitKey(1); + pthread_join(fetch_thread, 0); + pthread_join(detect_thread, 0); + + disp = det; + det = in; + det_s = in_s; + } +} +#else +extern "C" void demo_coco(char *cfgfile, char *weightfile, float thresh){ + fprintf(stderr, "YOLO-COCO demo needs OpenCV for webcam images.\n"); +} +#endif + diff --git a/src/data.c b/src/data.c index df15dc5..9b84c5a 100644 --- a/src/data.c +++ b/src/data.c @@ -574,9 +574,7 @@ pthread_t load_data_in_thread(load_args args) pthread_t thread; struct load_args *ptr = calloc(1, sizeof(struct load_args)); *ptr = args; - if(pthread_create(&thread, 0, load_thread, ptr)) { - error("Thread creation failed"); - } + if(pthread_create(&thread, 0, load_thread, ptr)) error("Thread creation failed"); return thread; } diff --git a/src/layer.h b/src/layer.h index 0137c8e..2a74437 100644 --- a/src/layer.h +++ b/src/layer.h @@ -15,7 +15,8 @@ typedef enum { ROUTE, COST, NORMALIZATION, - AVGPOOL + AVGPOOL, + LOCAL } LAYER_TYPE; typedef enum{ diff --git a/src/local_kernels.cu b/src/local_kernels.cu new file mode 100644 index 0000000..8717416 --- /dev/null +++ b/src/local_kernels.cu @@ -0,0 +1,226 @@ +extern "C" { +#include "local_layer.h" +#include "gemm.h" +#include "blas.h" +#include "im2col.h" +#include "col2im.h" +#include "utils.h" +#include "cuda.h" +} + +__global__ void scale_bias_kernel(float *output, float *biases, int n, int size) +{ + int offset = blockIdx.x * blockDim.x + threadIdx.x; + int filter = blockIdx.y; + int batch = blockIdx.z; + + if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter]; +} + +void scale_bias_gpu(float *output, float *biases, int batch, int n, int size) +{ + dim3 dimGrid((size-1)/BLOCK + 1, n, batch); + dim3 dimBlock(BLOCK, 1, 1); + + scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size); + check_error(cudaPeekAtLastError()); +} + +__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates) +{ + __shared__ float part[BLOCK]; + int i,b; + int filter = blockIdx.x; + int p = threadIdx.x; + float sum = 0; + for(b = 0; b < batch; ++b){ + for(i = 0; i < size; i += BLOCK){ + int index = p + i + size*(filter + n*b); + sum += (p+i < size) ? delta[index]*x_norm[index] : 0; + } + } + part[p] = sum; + __syncthreads(); + if (p == 0) { + for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i]; + } +} + +void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates) +{ + backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates); + check_error(cudaPeekAtLastError()); +} + +__global__ void add_bias_kernel(float *output, float *biases, int n, int size) +{ + int offset = blockIdx.x * blockDim.x + threadIdx.x; + int filter = blockIdx.y; + int batch = blockIdx.z; + + if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter]; +} + +void add_bias_gpu(float *output, float *biases, int batch, int n, int size) +{ + dim3 dimGrid((size-1)/BLOCK + 1, n, batch); + dim3 dimBlock(BLOCK, 1, 1); + + add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size); + check_error(cudaPeekAtLastError()); +} + +__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size) +{ + __shared__ float part[BLOCK]; + int i,b; + int filter = blockIdx.x; + int p = threadIdx.x; + float sum = 0; + for(b = 0; b < batch; ++b){ + for(i = 0; i < size; i += BLOCK){ + int index = p + i + size*(filter + n*b); + sum += (p+i < size) ? delta[index] : 0; + } + } + part[p] = sum; + __syncthreads(); + if (p == 0) { + for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i]; + } +} + +void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size) +{ + backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size); + check_error(cudaPeekAtLastError()); +} + +void forward_local_layer_gpu(local_layer l, network_state state) +{ + int i; + int m = l.n; + int k = l.size*l.size*l.c; + int n = local_out_height(l)* + local_out_width(l); + + fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1); + for(i = 0; i < l.batch; ++i){ + im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu); + float * a = l.filters_gpu; + float * b = l.col_image_gpu; + float * c = l.output_gpu; + gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n); + } + + if(l.batch_normalize){ + if(state.train){ + fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_mean_gpu, l.mean_gpu); + fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_variance_gpu, l.variance_gpu); + + scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1); + axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1); + scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1); + axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1); + + // cuda_pull_array(l.variance_gpu, l.mean, l.n); + // printf("%f\n", l.mean[0]); + + copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1); + normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w); + copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1); + } else { + normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w); + } + + scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w); + } + add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n); + + activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation); +} + +void backward_local_layer_gpu(local_layer l, network_state state) +{ + int i; + int m = l.n; + int n = l.size*l.size*l.c; + int k = local_out_height(l)* + local_out_width(l); + + gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu); + + backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k); + + if(l.batch_normalize){ + backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu); + + scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w); + + fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_mean_delta_gpu, l.mean_delta_gpu); + fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_variance_delta_gpu, l.variance_delta_gpu); + normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu); + } + + for(i = 0; i < l.batch; ++i){ + float * a = l.delta_gpu; + float * b = l.col_image_gpu; + float * c = l.filter_updates_gpu; + + im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu); + gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n); + + if(state.delta){ + float * a = l.filters_gpu; + float * b = l.delta_gpu; + float * c = l.col_image_gpu; + + gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k); + + col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w); + } + } +} + +void pull_local_layer(local_layer layer) +{ + cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size); + cuda_pull_array(layer.biases_gpu, layer.biases, layer.n); + cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size); + cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n); + if (layer.batch_normalize){ + cuda_pull_array(layer.scales_gpu, layer.scales, layer.n); + cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); + cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); + } +} + +void push_local_layer(local_layer layer) +{ + cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size); + cuda_push_array(layer.biases_gpu, layer.biases, layer.n); + cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size); + cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n); + if (layer.batch_normalize){ + cuda_push_array(layer.scales_gpu, layer.scales, layer.n); + cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); + cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); + } +} + +void update_local_layer_gpu(local_layer layer, int batch, float learning_rate, float momentum, float decay) +{ + int size = layer.size*layer.size*layer.c*layer.n; + + axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1); + scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1); + + axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1); + scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1); + + axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1); + axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1); + scal_ongpu(size, momentum, layer.filter_updates_gpu, 1); +} + + diff --git a/src/local_layer.c b/src/local_layer.c new file mode 100644 index 0000000..c0f52cb --- /dev/null +++ b/src/local_layer.c @@ -0,0 +1,275 @@ +#include "local_layer.h" +#include "utils.h" +#include "im2col.h" +#include "col2im.h" +#include "blas.h" +#include "gemm.h" +#include <stdio.h> +#include <time.h> + +int local_out_height(local_layer l) +{ + int h = l.h; + if (!l.pad) h -= l.size; + else h -= 1; + return h/l.stride + 1; +} + +int local_out_width(local_layer l) +{ + int w = l.w; + if (!l.pad) w -= l.size; + else w -= 1; + return w/l.stride + 1; +} + +local_layer make_local_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation) +{ + int i; + local_layer l = {0}; + l.type = LOCAL; + + l.h = h; + l.w = w; + l.c = c; + l.n = n; + l.batch = batch; + l.stride = stride; + l.size = size; + l.pad = pad; + + int out_h = local_out_height(l); + int out_w = local_out_width(l); + int locations = out_h*out_w; + l.out_h = out_h; + l.out_w = out_w; + l.out_c = n; + l.outputs = l.out_h * l.out_w * l.out_c; + l.inputs = l.w * l.h * l.c; + + l.filters = calloc(c*n*size*size*locations, sizeof(float)); + l.filter_updates = calloc(c*n*size*size*locations, sizeof(float)); + + l.biases = calloc(l.outputs, sizeof(float)); + l.bias_updates = calloc(l.outputs, sizeof(float)); + + // float scale = 1./sqrt(size*size*c); + float scale = sqrt(2./(size*size*c)); + for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale; + + l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float)); + l.output = calloc(l.batch*out_h * out_w * n, sizeof(float)); + l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float)); + +#ifdef GPU + l.filters_gpu = cuda_make_array(l.filters, c*n*size*size*locations); + l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size*locations); + + l.biases_gpu = cuda_make_array(l.biases, l.outputs); + l.bias_updates_gpu = cuda_make_array(l.bias_updates, l.outputs); + + l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c); + l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n); + l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + +#endif + l.activation = activation; + + fprintf(stderr, "Local Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); + + return l; +} + +void forward_local_layer(const local_layer l, network_state state) +{ + int out_h = local_out_height(l); + int out_w = local_out_width(l); + int i, j; + int locations = out_h * out_w; + + for(i = 0; i < l.batch; ++i){ + copy_cpu(l.outputs, l.biases, 1, l.output + i*l.outputs, 1); + } + + for(i = 0; i < l.batch; ++i){ + float *input = state.input + i*l.w*l.h*l.c; + im2col_cpu(input, l.c, l.h, l.w, + l.size, l.stride, l.pad, l.col_image); + float *output = l.output + i*l.outputs; + for(j = 0; j < locations; ++j){ + float *a = l.filters + j*l.size*l.size*l.c*l.n; + float *b = l.col_image + j; + float *c = output + j; + + int m = l.n; + int n = 1; + int k = l.size*l.size*l.c; + + gemm(0,0,m,n,k,1,a,k,b,locations,1,c,locations); + } + } + activate_array(l.output, l.outputs*l.batch, l.activation); +} + +void backward_local_layer(local_layer l, network_state state) +{ + int i, j; + int locations = l.out_w*l.out_h; + + gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta); + + for(i = 0; i < l.batch; ++i){ + axpy_cpu(l.outputs, 1, l.delta + i*l.outputs, 1, l.bias_updates, 1); + } + + for(i = 0; i < l.batch; ++i){ + float *input = state.input + i*l.w*l.h*l.c; + im2col_cpu(input, l.c, l.h, l.w, + l.size, l.stride, l.pad, l.col_image); + + for(j = 0; j < locations; ++j){ + float *a = l.delta + i*l.outputs + j; + float *b = l.col_image + j; + float *c = l.filter_updates + j*l.size*l.size*l.c*l.n; + int m = l.n; + int n = l.size*l.size*l.c; + int k = 1; + + gemm(0,1,m,n,k,1,a,locations,b,locations,1,c,n); + } + + if(state.delta){ + for(j = 0; j < locations; ++j){ + float *a = l.filters + j*l.size*l.size*l.c*l.n; + float *b = l.delta + i*l.outputs + j; + float *c = l.col_image + j; + + int m = l.size*l.size*l.c; + int n = 1; + int k = l.n; + + gemm(1,0,m,n,k,1,a,m,b,locations,0,c,locations); + } + + col2im_cpu(l.col_image, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w); + } + } +} + +void update_local_layer(local_layer l, int batch, float learning_rate, float momentum, float decay) +{ + int locations = l.out_w*l.out_h; + int size = l.size*l.size*l.c*l.n*locations; + axpy_cpu(l.outputs, learning_rate/batch, l.bias_updates, 1, l.biases, 1); + scal_cpu(l.outputs, momentum, l.bias_updates, 1); + + axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1); + axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1); + scal_cpu(size, momentum, l.filter_updates, 1); +} + +#ifdef GPU + +void forward_local_layer_gpu(const local_layer l, network_state state) +{ + int out_h = local_out_height(l); + int out_w = local_out_width(l); + int i, j; + int locations = out_h * out_w; + + for(i = 0; i < l.batch; ++i){ + copy_ongpu(l.outputs, l.biases_gpu, 1, l.output_gpu + i*l.outputs, 1); + } + + for(i = 0; i < l.batch; ++i){ + float *input = state.input + i*l.w*l.h*l.c; + im2col_ongpu(input, l.c, l.h, l.w, + l.size, l.stride, l.pad, l.col_image_gpu); + float *output = l.output_gpu + i*l.outputs; + for(j = 0; j < locations; ++j){ + float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n; + float *b = l.col_image_gpu + j; + float *c = output + j; + + int m = l.n; + int n = 1; + int k = l.size*l.size*l.c; + + gemm_ongpu(0,0,m,n,k,1,a,k,b,locations,1,c,locations); + } + } + activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation); +} + +void backward_local_layer_gpu(local_layer l, network_state state) +{ + int i, j; + int locations = l.out_w*l.out_h; + + gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu); + for(i = 0; i < l.batch; ++i){ + axpy_ongpu(l.outputs, 1, l.delta_gpu + i*l.outputs, 1, l.bias_updates_gpu, 1); + } + + for(i = 0; i < l.batch; ++i){ + float *input = state.input + i*l.w*l.h*l.c; + im2col_ongpu(input, l.c, l.h, l.w, + l.size, l.stride, l.pad, l.col_image_gpu); + + for(j = 0; j < locations; ++j){ + float *a = l.delta_gpu + i*l.outputs + j; + float *b = l.col_image_gpu + j; + float *c = l.filter_updates_gpu + j*l.size*l.size*l.c*l.n; + int m = l.n; + int n = l.size*l.size*l.c; + int k = 1; + + gemm_ongpu(0,1,m,n,k,1,a,locations,b,locations,1,c,n); + } + + if(state.delta){ + for(j = 0; j < locations; ++j){ + float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n; + float *b = l.delta_gpu + i*l.outputs + j; + float *c = l.col_image_gpu + j; + + int m = l.size*l.size*l.c; + int n = 1; + int k = l.n; + + gemm_ongpu(1,0,m,n,k,1,a,m,b,locations,0,c,locations); + } + + col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w); + } + } +} + +void update_local_layer_gpu(local_layer l, int batch, float learning_rate, float momentum, float decay) +{ + int locations = l.out_w*l.out_h; + int size = l.size*l.size*l.c*l.n*locations; + axpy_ongpu(l.outputs, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1); + scal_ongpu(l.outputs, momentum, l.bias_updates_gpu, 1); + + axpy_ongpu(size, -decay*batch, l.filters_gpu, 1, l.filter_updates_gpu, 1); + axpy_ongpu(size, learning_rate/batch, l.filter_updates_gpu, 1, l.filters_gpu, 1); + scal_ongpu(size, momentum, l.filter_updates_gpu, 1); +} + +void pull_local_layer(local_layer l) +{ + int locations = l.out_w*l.out_h; + int size = l.size*l.size*l.c*l.n*locations; + cuda_pull_array(l.filters_gpu, l.filters, size); + cuda_pull_array(l.biases_gpu, l.biases, l.outputs); +} + +void push_local_layer(local_layer l) +{ + int locations = l.out_w*l.out_h; + int size = l.size*l.size*l.c*l.n*locations; + cuda_push_array(l.filters_gpu, l.filters, size); + cuda_push_array(l.biases_gpu, l.biases, l.outputs); +} +#endif diff --git a/src/local_layer.h b/src/local_layer.h new file mode 100644 index 0000000..675a5fb --- /dev/null +++ b/src/local_layer.h @@ -0,0 +1,31 @@ +#ifndef LOCAL_LAYER_H +#define LOCAL_LAYER_H + +#include "cuda.h" +#include "params.h" +#include "image.h" +#include "activations.h" +#include "layer.h" + +typedef layer local_layer; + +#ifdef GPU +void forward_local_layer_gpu(local_layer layer, network_state state); +void backward_local_layer_gpu(local_layer layer, network_state state); +void update_local_layer_gpu(local_layer layer, int batch, float learning_rate, float momentum, float decay); + +void push_local_layer(local_layer layer); +void pull_local_layer(local_layer layer); +#endif + +local_layer make_local_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation); + +void forward_local_layer(const local_layer layer, network_state state); +void backward_local_layer(local_layer layer, network_state state); +void update_local_layer(local_layer layer, int batch, float learning_rate, float momentum, float decay); + +void bias_output(float *output, float *biases, int batch, int n, int size); +void backward_bias(float *bias_updates, float *delta, int batch, int n, int size); + +#endif + diff --git a/src/network.c b/src/network.c index 9bcb264..6c7461d 100644 --- a/src/network.c +++ b/src/network.c @@ -8,6 +8,7 @@ #include "crop_layer.h" #include "connected_layer.h" +#include "local_layer.h" #include "convolutional_layer.h" #include "deconvolutional_layer.h" #include "detection_layer.h" @@ -59,6 +60,8 @@ char *get_layer_string(LAYER_TYPE a) switch(a){ case CONVOLUTIONAL: return "convolutional"; + case LOCAL: + return "local"; case DECONVOLUTIONAL: return "deconvolutional"; case CONNECTED: @@ -112,6 +115,8 @@ void forward_network(network net, network_state state) forward_convolutional_layer(l, state); } else if(l.type == DECONVOLUTIONAL){ forward_deconvolutional_layer(l, state); + } else if(l.type == LOCAL){ + forward_local_layer(l, state); } else if(l.type == NORMALIZATION){ forward_normalization_layer(l, state); } else if(l.type == DETECTION){ @@ -150,6 +155,8 @@ void update_network(network net) update_deconvolutional_layer(l, rate, net.momentum, net.decay); } else if(l.type == CONNECTED){ update_connected_layer(l, update_batch, rate, net.momentum, net.decay); + } else if(l.type == LOCAL){ + update_local_layer(l, update_batch, rate, net.momentum, net.decay); } } } @@ -219,6 +226,8 @@ void backward_network(network net, network_state state) if(i != 0) backward_softmax_layer(l, state); } else if(l.type == CONNECTED){ backward_connected_layer(l, state); + } else if(l.type == LOCAL){ + backward_local_layer(l, state); } else if(l.type == COST){ backward_cost_layer(l, state); } else if(l.type == ROUTE){ diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 8561372..ffd5c59 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -19,6 +19,7 @@ extern "C" { #include "avgpool_layer.h" #include "normalization_layer.h" #include "cost_layer.h" +#include "local_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" #include "route_layer.h" @@ -41,6 +42,8 @@ void forward_network_gpu(network net, network_state state) forward_convolutional_layer_gpu(l, state); } else if(l.type == DECONVOLUTIONAL){ forward_deconvolutional_layer_gpu(l, state); + } else if(l.type == LOCAL){ + forward_local_layer_gpu(l, state); } else if(l.type == DETECTION){ forward_detection_layer_gpu(l, state); } else if(l.type == CONNECTED){ @@ -85,6 +88,8 @@ void backward_network_gpu(network net, network_state state) backward_convolutional_layer_gpu(l, state); } else if(l.type == DECONVOLUTIONAL){ backward_deconvolutional_layer_gpu(l, state); + } else if(l.type == LOCAL){ + backward_local_layer_gpu(l, state); } else if(l.type == MAXPOOL){ if(i != 0) backward_maxpool_layer_gpu(l, state); } else if(l.type == AVGPOOL){ @@ -120,6 +125,8 @@ void update_network_gpu(network net) update_deconvolutional_layer_gpu(l, rate, net.momentum, net.decay); } else if(l.type == CONNECTED){ update_connected_layer_gpu(l, update_batch, rate, net.momentum, net.decay); + } else if(l.type == LOCAL){ + update_local_layer_gpu(l, update_batch, rate, net.momentum, net.decay); } } } diff --git a/src/nightmare.c b/src/nightmare.c index 0eb3ca1..1a78dd5 100644 --- a/src/nightmare.c +++ b/src/nightmare.c @@ -25,7 +25,7 @@ void calculate_loss(float *output, float *delta, int n, float thresh) } } -void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh) +void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh, int norm) { scale_image(orig, 2); translate_image(orig, -1); @@ -85,7 +85,7 @@ void optimize_picture(network *net, image orig, int max_layer, float scale, floa //rate = rate / abs_mean(out.data, out.w*out.h*out.c); - normalize_array(out.data, out.w*out.h*out.c); + if(norm) normalize_array(out.data, out.w*out.h*out.c); axpy_cpu(orig.w*orig.h*orig.c, rate, out.data, 1, orig.data, 1); /* @@ -123,6 +123,7 @@ void run_nightmare(int argc, char **argv) int max_layer = atoi(argv[5]); int range = find_int_arg(argc, argv, "-range", 1); + int norm = find_int_arg(argc, argv, "-norm", 1); int rounds = find_int_arg(argc, argv, "-rounds", 1); int iters = find_int_arg(argc, argv, "-iters", 10); int octaves = find_int_arg(argc, argv, "-octaves", 4); @@ -160,7 +161,7 @@ void run_nightmare(int argc, char **argv) fflush(stderr); int layer = max_layer + rand()%range - range/2; int octave = rand()%octaves; - optimize_picture(&net, im, layer, 1/pow(1.33333333, octave), rate, thresh); + optimize_picture(&net, im, layer, 1/pow(1.33333333, octave), rate, thresh, norm); } fprintf(stderr, "done\n"); if(0){ diff --git a/src/parser.c b/src/parser.c index b095294..277c6e2 100644 --- a/src/parser.c +++ b/src/parser.c @@ -15,6 +15,7 @@ #include "dropout_layer.h" #include "detection_layer.h" #include "avgpool_layer.h" +#include "local_layer.h" #include "route_layer.h" #include "list.h" #include "option_list.h" @@ -27,6 +28,7 @@ typedef struct{ int is_network(section *s); int is_convolutional(section *s); +int is_local(section *s); int is_deconvolutional(section *s); int is_connected(section *s); int is_maxpool(section *s); @@ -107,6 +109,27 @@ deconvolutional_layer parse_deconvolutional(list *options, size_params params) return layer; } +local_layer parse_local(list *options, size_params params) +{ + int n = option_find_int(options, "filters",1); + int size = option_find_int(options, "size",1); + int stride = option_find_int(options, "stride",1); + int pad = option_find_int(options, "pad",0); + char *activation_s = option_find_str(options, "activation", "logistic"); + ACTIVATION activation = get_activation(activation_s); + + int batch,h,w,c; + h = params.h; + w = params.w; + c = params.c; + batch=params.batch; + if(!(h && w && c)) error("Layer before local layer must output image."); + + local_layer layer = make_local_layer(batch,h,w,c,n,size,stride,pad,activation); + + return layer; +} + convolutional_layer parse_convolutional(list *options, size_params params) { int n = option_find_int(options, "filters",1); @@ -402,6 +425,8 @@ network parse_network_cfg(char *filename) layer l = {0}; if(is_convolutional(s)){ l = parse_convolutional(options, params); + }else if(is_local(s)){ + l = parse_local(options, params); }else if(is_deconvolutional(s)){ l = parse_deconvolutional(options, params); }else if(is_connected(s)){ @@ -465,6 +490,10 @@ int is_detection(section *s) { return (strcmp(s->type, "[detection]")==0); } +int is_local(section *s) +{ + return (strcmp(s->type, "[local]")==0); +} int is_deconvolutional(section *s) { return (strcmp(s->type, "[deconv]")==0 @@ -626,6 +655,16 @@ void save_weights_upto(network net, char *filename, int cutoff) #endif fwrite(l.biases, sizeof(float), l.outputs, fp); fwrite(l.weights, sizeof(float), l.outputs*l.inputs, fp); + } if(l.type == LOCAL){ +#ifdef GPU + if(gpu_index >= 0){ + pull_local_layer(l); + } +#endif + int locations = l.out_w*l.out_h; + int size = l.size*l.size*l.c*l.n*locations; + fwrite(l.biases, sizeof(float), l.outputs, fp); + fwrite(l.filters, sizeof(float), size, fp); } } fclose(fp); @@ -684,6 +723,17 @@ void load_weights_upto(network *net, char *filename, int cutoff) if(gpu_index >= 0){ push_connected_layer(l); } +#endif + } + if(l.type == LOCAL){ + int locations = l.out_w*l.out_h; + int size = l.size*l.size*l.c*l.n*locations; + fread(l.biases, sizeof(float), l.outputs, fp); + fread(l.filters, sizeof(float), size, fp); +#ifdef GPU + if(gpu_index >= 0){ + push_local_layer(l); + } #endif } } diff --git a/src/yolo.c b/src/yolo.c index 2abfa13..7da69f7 100644 --- a/src/yolo.c +++ b/src/yolo.c @@ -11,7 +11,7 @@ char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; -void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label) +void draw_yolo(image im, int num, float thresh, box *boxes, float **probs) { int classes = 20; int i; @@ -20,8 +20,10 @@ void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char int class = max_index(probs[i], classes); float prob = probs[i][class]; if(prob > thresh){ - int width = pow(prob, 1./2.)*10; - printf("%f %s\n", prob, voc_names[class]); + int width = pow(prob, 1./2.)*10+1; + //width = 8; + printf("%s: %.2f\n", voc_names[class], prob); + class = class * 7 % 20; float red = get_color(0,class,classes); float green = get_color(1,class,classes); float blue = get_color(2,class,classes); @@ -41,7 +43,6 @@ void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char draw_box_width(im, left, top, right, bot, width, red, green, blue); } } - show_image(im, label); } void train_yolo(char *cfgfile, char *weightfile) @@ -97,21 +98,13 @@ void train_yolo(char *cfgfile, char *weightfile) printf("Loaded: %lf seconds\n", sec(clock()-time)); - /* - image im = float_to_image(net.w, net.h, 3, train.X.vals[113]); - image copy = copy_image(im); - draw_yolo(copy, train.y.vals[113], 7, "truth"); - cvWaitKey(0); - free_image(copy); - */ - time=clock(); float loss = train_network(net, train); if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs); - if(i%1000==0){ + if(i%1000==0 || i == 600){ char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); save_weights(net, buff); @@ -183,8 +176,8 @@ void validate_yolo(char *cfgfile, char *weightfile) srand(time(0)); char *base = "results/comp4_det_test_"; - //list *plist = get_paths("data/voc.2007.test"); - list *plist = get_paths("data/voc.2012.test"); + list *plist = get_paths("data/voc.2007.test"); + //list *plist = get_paths("data/voc.2012.test"); char **paths = (char **)list_to_array(plist); layer l = net.layers[net.n-1]; @@ -384,7 +377,8 @@ void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh) printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms); - draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions"); + draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs); + show_image(im, "predictions"); show_image(sized, "resized"); free_image(im); diff --git a/src/yolo_kernels.cu b/src/yolo_kernels.cu index f02b7a2..487e9bd 100644 --- a/src/yolo_kernels.cu +++ b/src/yolo_kernels.cu @@ -6,6 +6,7 @@ extern "C" { #include "parser.h" #include "box.h" #include "image.h" +#include <sys/time.h> } #ifdef OPENCV @@ -13,48 +14,108 @@ extern "C" { #include "opencv2/imgproc/imgproc.hpp" extern "C" image ipl_to_image(IplImage* src); extern "C" void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness); -extern "C" void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label); +extern "C" void draw_yolo(image im, int num, float thresh, box *boxes, float **probs); + +static float **probs; +static box *boxes; +static network net; +static image in ; +static image in_s ; +static image det ; +static image det_s; +static image disp ; +static cv::VideoCapture cap; +static float fps = 0; + +void *fetch_in_thread(void *ptr) +{ + cv::Mat frame_m; + cap >> frame_m; + IplImage frame = frame_m; + in = ipl_to_image(&frame); + rgbgr_image(in); + in_s = resize_image(in, net.w, net.h); + return 0; +} + +void *detect_in_thread(void *ptr) +{ + float nms = .4; + float thresh = .2; + + detection_layer l = net.layers[net.n-1]; + float *X = det_s.data; + float *predictions = network_predict(net, X); + free_image(det_s); + convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); + if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms); + printf("\033[2J"); + printf("\033[1;1H"); + printf("\nFPS:%.0f\n",fps); + printf("Objects:\n\n"); + draw_yolo(det, l.side*l.side*l.n, thresh, boxes, probs); + return 0; +} extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh) { - network net = parse_network_cfg(cfgfile); + printf("YOLO demo\n"); + net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); } - detection_layer l = net.layers[net.n-1]; - cv::VideoCapture cap(0); - set_batch_network(&net, 1); + srand(2222222); - float nms = .4; + + cv::VideoCapture cam(0); + cap = cam; + if(!cap.isOpened()) error("Couldn't connect to webcam.\n"); + + detection_layer l = net.layers[net.n-1]; int j; - box *boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box)); - float **probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *)); + + boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box)); + probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *)); for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *)); + pthread_t fetch_thread; + pthread_t detect_thread; + + fetch_in_thread(0); + det = in; + det_s = in_s; + + fetch_in_thread(0); + detect_in_thread(0); + disp = det; + det = in; + det_s = in_s; + while(1){ - cv::Mat frame_m; - cap >> frame_m; - IplImage frame = frame_m; - image im = ipl_to_image(&frame); - rgbgr_image(im); - - image sized = resize_image(im, net.w, net.h); - float *X = sized.data; - float *predictions = network_predict(net, X); - convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); - if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms); - printf("\033[2J"); - printf("\033[1;1H"); - printf("\nObjects:\n\n"); - draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions"); - - free_image(im); - free_image(sized); + struct timeval tval_before, tval_after, tval_result; + gettimeofday(&tval_before, NULL); + if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed"); + if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed"); + show_image(disp, "YOLO"); + free_image(disp); cvWaitKey(1); + pthread_join(fetch_thread, 0); + pthread_join(detect_thread, 0); + + disp = det; + det = in; + det_s = in_s; + + gettimeofday(&tval_after, NULL); + timersub(&tval_after, &tval_before, &tval_result); + float curr = 1000000.f/((long int)tval_result.tv_usec); + fps = .9*fps + .1*curr; } } #else -extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh){} +extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh){ + fprintf(stderr, "YOLO demo needs OpenCV for webcam images.\n"); +} #endif -- GitLab