diff --git a/.gitignore b/.gitignore
index 12d2c999c2a14558a04c700f83ac359c17001c36..9057c39dd0f4ca6011f279e52d22028ccc68052a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,12 +2,18 @@
 *.dSYM
 *.csv
 *.out
+*.png
+*.sh
 mnist/
+data/
+caffe/
+grasp/
 images/
 opencv/
 convnet/
 decaf/
 submission/
+cfg/
 darknet
 
 # OS Generated #
diff --git a/Makefile b/Makefile
index 6e7ecf752ce7a5ba4d1061996c6f7d01ed0cdb98..12432b9280de1fd6bd5c3eec80cda9b4eb9f4cce 100644
--- a/Makefile
+++ b/Makefile
@@ -9,7 +9,7 @@ OBJDIR=./obj/
 CC=gcc
 NVCC=nvcc
 OPTS=-O3
-LDFLAGS=`pkg-config --libs opencv` -lm -pthread
+LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++
 COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/
 CFLAGS=-Wall -Wfatal-errors
 
@@ -25,7 +25,7 @@ CFLAGS+=-DGPU
 LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas
 endif
 
-OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o
+OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o
 ifeq ($(GPU), 1) 
 OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
 endif
diff --git a/src/cost_layer.c b/src/cost_layer.c
index 34c8fb59c5cccbb0ccbdb91ea4a15b1fa253019f..815827510be29d9af8ba45530f147cd87fcf2899 100644
--- a/src/cost_layer.c
+++ b/src/cost_layer.c
@@ -10,7 +10,6 @@
 COST_TYPE get_cost_type(char *s)
 {
     if (strcmp(s, "sse")==0) return SSE;
-    if (strcmp(s, "detection")==0) return DETECTION;
     fprintf(stderr, "Couldn't find activation function %s, going with SSE\n", s);
     return SSE;
 }
@@ -20,8 +19,6 @@ char *get_cost_string(COST_TYPE a)
     switch(a){
         case SSE:
             return "sse";
-        case DETECTION:
-            return "detection";
     }
     return "sse";
 }
@@ -41,17 +38,20 @@ cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type)
     return layer;
 }
 
+void pull_cost_layer(cost_layer layer)
+{
+    cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
+}
+void push_cost_layer(cost_layer layer)
+{
+    cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
+}
+
 void forward_cost_layer(cost_layer layer, float *input, float *truth)
 {
     if (!truth) return;
     copy_cpu(layer.batch*layer.inputs, truth, 1, layer.delta, 1);
     axpy_cpu(layer.batch*layer.inputs, -1, input, 1, layer.delta, 1);
-    if(layer.type == DETECTION){
-        int i;
-        for(i = 0; i < layer.batch*layer.inputs; ++i){
-            if((i%25) && !truth[(i/25)*25]) layer.delta[i] = 0;
-        }
-    }
     *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
     //printf("cost: %f\n", *layer.output);
 }
@@ -66,14 +66,21 @@ void backward_cost_layer(const cost_layer layer, float *input, float *delta)
 void forward_cost_layer_gpu(cost_layer layer, float * input, float * truth)
 {
     if (!truth) return;
+    
+    /*
+    float *in = calloc(layer.inputs*layer.batch, sizeof(float));
+    float *t = calloc(layer.inputs*layer.batch, sizeof(float));
+    cuda_pull_array(input, in, layer.batch*layer.inputs);
+    cuda_pull_array(truth, t, layer.batch*layer.inputs);
+    forward_cost_layer(layer, in, t);
+    cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
+    free(in);
+    free(t);
+    */
 
     copy_ongpu(layer.batch*layer.inputs, truth, 1, layer.delta_gpu, 1);
     axpy_ongpu(layer.batch*layer.inputs, -1, input, 1, layer.delta_gpu, 1);
 
-    if(layer.type==DETECTION){
-        mask_ongpu(layer.inputs*layer.batch, layer.delta_gpu, truth, 25);
-    }
-
     cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
     *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
     //printf("cost: %f\n", *layer.output);
diff --git a/src/cost_layer.h b/src/cost_layer.h
index e58aae1fa0545c0e3e93377f6211d4aa24773910..0855405276b6aa64956c547d46481b06e501a598 100644
--- a/src/cost_layer.h
+++ b/src/cost_layer.h
@@ -2,12 +2,14 @@
 #define COST_LAYER_H
 
 typedef enum{
-    SSE, DETECTION
+    SSE
 } COST_TYPE;
 
 typedef struct {
     int inputs;
     int batch;
+    int coords;
+    int classes;
     float *delta;
     float *output;
     COST_TYPE type;
diff --git a/src/cuda.c b/src/cuda.c
index 8849fb1fcdba3f6cd10792e633167d2fcbe532c0..c9142905af50ea4e94c748dae96b5affebb76f74 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -5,6 +5,7 @@ int gpu_index = 0;
 #include "cuda.h"
 #include "utils.h"
 #include "blas.h"
+#include "assert.h"
 #include <stdlib.h>
 
 
@@ -15,6 +16,7 @@ void check_error(cudaError_t status)
         const char *s = cudaGetErrorString(status);
         char buffer[256];
         printf("CUDA Error: %s\n", s);
+        assert(0);
         snprintf(buffer, 256, "CUDA Error: %s", s);
         error(buffer);
     } 
diff --git a/src/darknet.c b/src/darknet.c
index fc58f3d17134b8b695005ee98ec1863be0377022..413d7f2186f03c89df198255be2b422d85e02af7 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -36,42 +36,30 @@ char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
 void draw_detection(image im, float *box, int side)
 {
     int classes = 20;
-    int elems = 4+classes+1;
+    int elems = 4+classes;
     int j;
     int r, c;
-    float amount[AMNT] = {0};
-    for(r = 0; r < side*side; ++r){
-        float val = box[r*elems];
-        for(j = 0; j < AMNT; ++j){
-            if(val > amount[j]) {
-                float swap = val;
-                val = amount[j];
-                amount[j] = swap;
-            }
-        }
-    }
-    float smallest = amount[AMNT-1];
 
     for(r = 0; r < side; ++r){
         for(c = 0; c < side; ++c){
             j = (r*side + c) * elems;
             //printf("%d\n", j);
             //printf("Prob: %f\n", box[j]);
-            if(box[j] >= smallest){
-                int class = max_index(box+j+1, classes);
-                int z;
-                for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+1+z], class_names[z]);
-                printf("%f %s\n", box[j+1+class], class_names[class]);
+            int class = max_index(box+j, classes);
+            if(box[j+class] > .02 || 1){
+                //int z;
+                //for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]);
+                printf("%f %s\n", box[j+class], class_names[class]);
                 float red = get_color(0,class,classes);
                 float green = get_color(1,class,classes);
                 float blue = get_color(2,class,classes);
 
                 j += classes;
                 int d = im.w/side;
-                int y = r*d+box[j+1]*d;
-                int x = c*d+box[j+2]*d;
-                int h = box[j+3]*im.h;
-                int w = box[j+4]*im.w;
+                int y = r*d+box[j]*d;
+                int x = c*d+box[j+1]*d;
+                int h = box[j+2]*im.h;
+                int w = box[j+3]*im.w;
                 draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2,red,green,blue);
             }
         }
@@ -117,21 +105,22 @@ void train_detection_net(char *cfgfile, char *weightfile)
     data train, buffer;
     int im_dim = 512;
     int jitter = 64;
-    pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer);
+    int classes = 21;
+    pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);
     clock_t time;
     while(1){
         i += 1;
         time=clock();
         pthread_join(load_thread, 0);
         train = buffer;
-        load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer);
+        load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);
 
-/*
-        image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]);
-        draw_detection(im, train.y.vals[0], 7);
-        show_image(im, "truth");
-        cvWaitKey(0);
-        */
+        /*
+           image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]);
+           draw_detection(im, train.y.vals[0], 7);
+           show_image(im, "truth");
+           cvWaitKey(0);
+         */
 
         printf("Loaded: %lf seconds\n", sec(clock()-time));
         time=clock();
@@ -139,7 +128,7 @@ void train_detection_net(char *cfgfile, char *weightfile)
         net.seen += imgs;
         avg_loss = avg_loss*.9 + loss*.1;
         printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
-        if(i%800==0){
+        if(i%100==0){
             char buff[256];
             sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
             save_weights(net, buff);
@@ -161,7 +150,7 @@ void validate_detection_net(char *cfgfile, char *weightfile)
     char **paths = (char **)list_to_array(plist);
     int num_output = 1225;
     int im_size = 448;
-    int classes = 20;
+    int classes = 21;
 
     int m = plist->size;
     int i = 0;
@@ -180,30 +169,29 @@ void validate_detection_net(char *cfgfile, char *weightfile)
         num = (i+1)*m/splits - i*m/splits;
         char **part = paths+(i*m/splits);
         if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, im_size, im_size, &buffer);
- 
+
         fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
         matrix pred = network_predict_data(net, val);
         int j, k, class;
         for(j = 0; j < pred.rows; ++j){
-            for(k = 0; k < pred.cols; k += classes+4+1){
+            for(k = 0; k < pred.cols; k += classes+4){
 
                 /*
-                int z;
-                for(z = 0; z < 25; ++z) printf("%f, ", pred.vals[j][k+z]);
-                printf("\n");
-                */
+                   int z;
+                   for(z = 0; z < 25; ++z) printf("%f, ", pred.vals[j][k+z]);
+                   printf("\n");
+                 */
 
-                float p = pred.vals[j][k];
                 //if (pred.vals[j][k] > .001){
-                for(class = 0; class < classes; ++class){
-                    int index = (k)/(classes+4+1); 
+                for(class = 0; class < classes-1; ++class){
+                    int index = (k)/(classes+4); 
                     int r = index/7;
                     int c = index%7;
-                    float y = (r + pred.vals[j][k+1+classes])/7.;
-                    float x = (c + pred.vals[j][k+2+classes])/7.;
-                    float h = pred.vals[j][k+3+classes];
-                    float w = pred.vals[j][k+4+classes];
-                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, p*pred.vals[j][k+class+1], y, x, h, w);
+                    float y = (r + pred.vals[j][k+0+classes])/7.;
+                    float x = (c + pred.vals[j][k+1+classes])/7.;
+                    float h = pred.vals[j][k+2+classes];
+                    float w = pred.vals[j][k+3+classes];
+                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class], y, x, h, w);
                 }
                 //}
             }
@@ -462,7 +450,7 @@ void test_detection(char *cfgfile, char *weightfile)
     if(weightfile){
         load_weights(&net, weightfile);
     }
-    int im_size = 224;
+    int im_size = 448;
     set_batch_network(&net, 1);
     srand(2222222);
     clock_t time;
diff --git a/src/data.c b/src/data.c
index a6b6db36a92785471b310eeef02fb4e83ebff315..0c935976863b520a3a56622de12611ea7d3f02ae 100644
--- a/src/data.c
+++ b/src/data.c
@@ -89,8 +89,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int
         float dw = (x - i*box_width)/box_width;
         float dh = (y - j*box_height)/box_height;
         //printf("%d %d %d %f %f\n", id, i, j, dh, dw);
-        int index = (i+j*num_width)*(4+classes+1);
-        truth[index++] = 1;
+        int index = (i+j*num_width)*(4+classes);
         truth[index+id] = 1;
         index += classes;
         truth[index++] = dh;
@@ -98,6 +97,12 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int
         truth[index++] = h*(height+jitter)/height;
         truth[index++] = w*(width+jitter)/width;
     }
+    int i, j;
+    for(i = 0; i < num_height*num_width*(4+classes); i += 4+classes){
+        int background = 1;
+        for(j = i; j < i+classes; ++j) if (truth[j]) background = 0;
+        truth[i+classes-1] = background;
+    }
     fclose(file);
 }
 
@@ -209,7 +214,7 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes,
     data d;
     d.shallow = 0;
     d.X = load_image_paths(random_paths, n, h, w);
-    int k = nh*nw*(4+classes+1);
+    int k = nh*nw*(4+classes);
     d.y = make_matrix(n, k);
     for(i = 0; i < n; ++i){
         int dx = rand()%jitter;
diff --git a/src/detection_layer.c b/src/detection_layer.c
index 65370795add253143e2f913fcfcde6472ac126d7..bbc2e4ff95172f21c570c2e1e8674da937c7073b 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -1,72 +1,123 @@
-int detection_out_height(detection_layer layer)
+#include "detection_layer.h"
+#include "activations.h"
+#include "softmax_layer.h"
+#include "blas.h"
+#include "cuda.h"
+#include <stdio.h>
+#include <stdlib.h>
+
+int get_detection_layer_locations(detection_layer layer)
 {
-    return layer.size + layer.h*layer.stride;
+    return layer.inputs / (layer.classes+layer.coords+layer.rescore);
 }
 
-int detection_out_width(detection_layer layer)
+int get_detection_layer_output_size(detection_layer layer)
 {
-    return layer.size + layer.w*layer.stride;
+    return get_detection_layer_locations(layer)*(layer.classes+layer.coords);
 }
 
-detection_layer *make_detection_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
+detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore)
 {
-    int i;
-    size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter...
     detection_layer *layer = calloc(1, sizeof(detection_layer));
-    layer->h = h;
-    layer->w = w;
-    layer->c = c;
-    layer->n = n;
-    layer->batch = batch;
-    layer->stride = stride;
-    layer->size = size;
-    assert(c%n == 0);
-
-    layer->filters = calloc(c*size*size, sizeof(float));
-    layer->filter_updates = calloc(c*size*size, sizeof(float));
-    layer->filter_momentum = calloc(c*size*size, sizeof(float));
-
-    float scale = 1./(size*size*c);
-    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform());
-
-    int out_h = detection_out_height(*layer);
-    int out_w = detection_out_width(*layer);
-
-    layer->output = calloc(layer->batch * out_h * out_w * n, sizeof(float));
-    layer->delta  = calloc(layer->batch * out_h * out_w * n, sizeof(float));
 
-    layer->activation = activation;
+    layer->batch = batch;
+    layer->inputs = inputs;
+    layer->classes = classes;
+    layer->coords = coords;
+    layer->rescore = rescore;
+    int outputs = get_detection_layer_output_size(*layer);
+    layer->output = calloc(batch*outputs, sizeof(float));
+    layer->delta = calloc(batch*outputs, sizeof(float));
+    #ifdef GPU
+    layer->output_gpu = cuda_make_array(0, batch*outputs);
+    layer->delta_gpu = cuda_make_array(0, batch*outputs);
+    #endif
 
-    fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
+    fprintf(stderr, "Detection Layer\n");
     srand(0);
 
     return layer;
 }
 
-void forward_detection_layer(const detection_layer layer, float *in)
+void forward_detection_layer(const detection_layer layer, float *in, float *truth)
 {
-    int out_h = detection_out_height(layer);
-    int out_w = detection_out_width(layer);
-    int i,j,fh, fw,c;
-    memset(layer.output, 0, layer->batch*layer->n*out_h*out_w*sizeof(float));
-    for(c = 0; c < layer.c; ++c){
-        for(i = 0; i < layer.h; ++i){
-            for(j = 0; j < layer.w; ++j){
-                float val = layer->input[j+(i + c*layer.h)*layer.w];
-                for(fh = 0; fh < layer.size; ++fh){
-                    for(fw = 0; fw < layer.size; ++fw){
-                        int h = i*layer.stride + fh;
-                        int w = j*layer.stride + fw;
-                        layer.output[w+(h+c/n*out_h)*out_w] += val*layer->filters[fw+(fh+c*layer.size)*layer.size];
-                    }
-                }
-            }
+    int in_i = 0;
+    int out_i = 0;
+    int locations = get_detection_layer_locations(layer);
+    int i,j;
+    for(i = 0; i < layer.batch*locations; ++i){
+        int mask = (!truth || !truth[out_i + layer.classes - 1]);
+        float scale = 1;
+        if(layer.rescore) scale = in[in_i++];
+        for(j = 0; j < layer.classes; ++j){
+            layer.output[out_i++] = scale*in[in_i++];
+        }
+        softmax_array(layer.output + out_i - layer.classes, layer.classes, layer.output + out_i - layer.classes);
+        activate_array(layer.output+out_i, layer.coords, SIGMOID);
+        for(j = 0; j < layer.coords; ++j){
+            layer.output[out_i++] = mask*in[in_i++];
         }
+        //printf("%d\n", mask);
+        //for(j = 0; j < layer.classes+layer.coords; ++j) printf("%f ", layer.output[i*(layer.classes+layer.coords)+j]);
+        //printf ("\n");
     }
 }
 
-void backward_detection_layer(const detection_layer layer, float *delta)
+void backward_detection_layer(const detection_layer layer, float *in, float *delta)
 {
+    int locations = get_detection_layer_locations(layer);
+    int i,j;
+    int in_i = 0;
+    int out_i = 0;
+    for(i = 0; i < layer.batch*locations; ++i){
+        float scale = 1;
+        float latent_delta = 0;
+        if(layer.rescore) scale = in[in_i++];
+        for(j = 0; j < layer.classes; ++j){
+            latent_delta += in[in_i]*layer.delta[out_i];
+            delta[in_i++] = scale*layer.delta[out_i++];
+        }
+        
+        for(j = 0; j < layer.coords; ++j){
+            delta[in_i++] = layer.delta[out_i++];
+        }
+        gradient_array(in + in_i - layer.coords, layer.coords, SIGMOID, layer.delta + out_i - layer.coords); 
+        if(layer.rescore) delta[in_i-layer.coords-layer.classes-layer.rescore] = latent_delta;
+    }
 }
 
+#ifdef GPU
+
+void forward_detection_layer_gpu(const detection_layer layer, float *in, float *truth)
+{
+    int outputs = get_detection_layer_output_size(layer);
+    float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float));
+    float *truth_cpu = 0;
+    if(truth){
+        truth_cpu = calloc(layer.batch*outputs, sizeof(float));
+        cuda_pull_array(truth, truth_cpu, layer.batch*outputs);
+    }
+    cuda_pull_array(in, in_cpu, layer.batch*layer.inputs);
+    forward_detection_layer(layer, in_cpu, truth_cpu);
+    cuda_push_array(layer.output_gpu, layer.output, layer.batch*outputs);
+    free(in_cpu);
+    if(truth_cpu) free(truth_cpu);
+}
+
+void backward_detection_layer_gpu(detection_layer layer, float *in, float *delta)
+{
+    int outputs = get_detection_layer_output_size(layer);
+
+    float *in_cpu =    calloc(layer.batch*layer.inputs, sizeof(float));
+    float *delta_cpu = calloc(layer.batch*layer.inputs, sizeof(float));
+
+    cuda_pull_array(in, in_cpu, layer.batch*layer.inputs);
+    cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*outputs);
+    backward_detection_layer(layer, in_cpu, delta_cpu);
+    cuda_push_array(delta, delta_cpu, layer.batch*layer.inputs);
+
+    free(in_cpu);
+    free(delta_cpu);
+}
+#endif
 
diff --git a/src/detection_layer.h b/src/detection_layer.h
index fad0281e97ab0e10ed8f402828aea95de3732c3a..e7e9e2064d3db2d9dc2bc69e36302f946c41a848 100644
--- a/src/detection_layer.h
+++ b/src/detection_layer.h
@@ -3,38 +3,26 @@
 
 typedef struct {
     int batch;
-    int h,w,c;
-    int n;
-    int size;
-    int stride;
-
-    float *filters;
-    float *filter_updates;
-    float *filter_momentum;
-
-    float *biases;
-    float *bias_updates;
-    float *bias_momentum;
-
-    float *col_image;
-    float *delta;
+    int inputs;
+    int classes;
+    int coords;
+    int rescore;
     float *output;
-
+    float *delta;
     #ifdef GPU
-    cl_mem filters_cl;
-    cl_mem filter_updates_cl;
-    cl_mem filter_momentum_cl;
-
-    cl_mem biases_cl;
-    cl_mem bias_updates_cl;
-    cl_mem bias_momentum_cl;
-
-    cl_mem col_image_cl;
-    cl_mem delta_cl;
-    cl_mem output_cl;
+    float * output_gpu;
+    float * delta_gpu;
     #endif
+} detection_layer;
 
-    ACTIVATION activation;
-} convolutional_layer;
+detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore);
+void forward_detection_layer(const detection_layer layer, float *in, float *truth);
+void backward_detection_layer(const detection_layer layer, float *in, float *delta);
+int get_detection_layer_output_size(detection_layer layer);
+
+#ifdef GPU
+void forward_detection_layer_gpu(const detection_layer layer, float *in, float *truth);
+void backward_detection_layer_gpu(detection_layer layer, float *in, float *delta);
+#endif
 
 #endif
diff --git a/src/image.c b/src/image.c
index 53cf281c1f96ae241b092329fcdcfe0100f745de..ee7a823fdbdec2957c59811cdef51be2e7591246 100644
--- a/src/image.c
+++ b/src/image.c
@@ -13,7 +13,7 @@ float get_color(int c, int x, int max)
     int j = ceil(ratio);
     ratio -= i;
     float r = (1-ratio) * colors[i][c] + ratio*colors[j][c];
-    printf("%f\n", r);
+    //printf("%f\n", r);
     return r;
 }
 
diff --git a/src/network.c b/src/network.c
index bf0d63f976d8fa11c2f4f72e91a91af79bd5f766..b60f05953572bc1c99e3be9c2a143ac5c3a6aa11 100644
--- a/src/network.c
+++ b/src/network.c
@@ -9,6 +9,7 @@
 #include "connected_layer.h"
 #include "convolutional_layer.h"
 #include "deconvolutional_layer.h"
+#include "detection_layer.h"
 #include "maxpool_layer.h"
 #include "cost_layer.h"
 #include "normalization_layer.h"
@@ -29,6 +30,8 @@ char *get_layer_string(LAYER_TYPE a)
             return "maxpool";
         case SOFTMAX:
             return "softmax";
+        case DETECTION:
+            return "detection";
         case NORMALIZATION:
             return "normalization";
         case DROPOUT:
@@ -76,6 +79,11 @@ void forward_network(network net, float *input, float *truth, int train)
             forward_deconvolutional_layer(layer, input);
             input = layer.output;
         }
+        else if(net.types[i] == DETECTION){
+            detection_layer layer = *(detection_layer *)net.layers[i];
+            forward_detection_layer(layer, input, truth);
+            input = layer.output;
+        }
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
             forward_connected_layer(layer, input);
@@ -152,6 +160,9 @@ float *get_network_output_layer(network net, int i)
     } else if(net.types[i] == MAXPOOL){
         maxpool_layer layer = *(maxpool_layer *)net.layers[i];
         return layer.output;
+    } else if(net.types[i] == DETECTION){
+        detection_layer layer = *(detection_layer *)net.layers[i];
+        return layer.output;
     } else if(net.types[i] == SOFTMAX){
         softmax_layer layer = *(softmax_layer *)net.layers[i];
         return layer.output;
@@ -193,6 +204,9 @@ float *get_network_delta_layer(network net, int i)
     } else if(net.types[i] == SOFTMAX){
         softmax_layer layer = *(softmax_layer *)net.layers[i];
         return layer.delta;
+    } else if(net.types[i] == DETECTION){
+        detection_layer layer = *(detection_layer *)net.layers[i];
+        return layer.delta;
     } else if(net.types[i] == DROPOUT){
         if(i == 0) return 0;
         return get_network_delta_layer(net, i-1);
@@ -243,7 +257,7 @@ int get_predicted_class_network(network net)
     return max_index(out, k);
 }
 
-void backward_network(network net, float *input)
+void backward_network(network net, float *input, float *truth)
 {
     int i;
     float *prev_input;
@@ -272,6 +286,10 @@ void backward_network(network net, float *input)
             dropout_layer layer = *(dropout_layer *)net.layers[i];
             backward_dropout_layer(layer, prev_delta);
         }
+        else if(net.types[i] == DETECTION){
+            detection_layer layer = *(detection_layer *)net.layers[i];
+            backward_detection_layer(layer, prev_input, prev_delta);
+        }
         else if(net.types[i] == NORMALIZATION){
             normalization_layer layer = *(normalization_layer *)net.layers[i];
             if(i != 0) backward_normalization_layer(layer, prev_input, prev_delta);
@@ -297,7 +315,7 @@ float train_network_datum(network net, float *x, float *y)
     if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
     #endif
     forward_network(net, x, y, 1);
-    backward_network(net, x);
+    backward_network(net, x, y);
     float error = get_network_cost(net);
     update_network(net);
     return error;
@@ -351,7 +369,7 @@ float train_network_batch(network net, data d, int n)
             float *x = d.X.vals[index];
             float *y = d.y.vals[index];
             forward_network(net, x, y, 1);
-            backward_network(net, x);
+            backward_network(net, x, y);
             sum += get_network_cost(net);
         }
         update_network(net);
@@ -381,7 +399,6 @@ void set_learning_network(network *net, float rate, float momentum, float decay)
     }
 }
 
-
 void set_batch_network(network *net, int b)
 {
     net->batch = b;
@@ -404,6 +421,9 @@ void set_batch_network(network *net, int b)
         } else if(net->types[i] == DROPOUT){
             dropout_layer *layer = (dropout_layer *) net->layers[i];
             layer->batch = b;
+        } else if(net->types[i] == DETECTION){
+            detection_layer *layer = (detection_layer *) net->layers[i];
+            layer->batch = b;
         }
         else if(net->types[i] == FREEWEIGHT){
             freeweight_layer *layer = (freeweight_layer *) net->layers[i];
@@ -445,6 +465,9 @@ int get_network_input_size_layer(network net, int i)
     } else if(net.types[i] == DROPOUT){
         dropout_layer layer = *(dropout_layer *) net.layers[i];
         return layer.inputs;
+    } else if(net.types[i] == DETECTION){
+        detection_layer layer = *(detection_layer *) net.layers[i];
+        return layer.inputs;
     } else if(net.types[i] == CROP){
         crop_layer layer = *(crop_layer *) net.layers[i];
         return layer.c*layer.h*layer.w;
@@ -473,6 +496,10 @@ int get_network_output_size_layer(network net, int i)
         image output = get_deconvolutional_image(layer);
         return output.h*output.w*output.c;
     }
+    else if(net.types[i] == DETECTION){
+        detection_layer layer = *(detection_layer *)net.layers[i];
+        return get_detection_layer_output_size(layer);
+    }
     else if(net.types[i] == MAXPOOL){
         maxpool_layer layer = *(maxpool_layer *)net.layers[i];
         image output = get_maxpool_image(layer);
diff --git a/src/network.h b/src/network.h
index 66873d2c0ffe3f0e1dc2984255835e6c4f037c0e..d2fb346cee9c5a31be0f44b2bb24829213eac35a 100644
--- a/src/network.h
+++ b/src/network.h
@@ -11,6 +11,7 @@ typedef enum {
     CONNECTED,
     MAXPOOL,
     SOFTMAX,
+    DETECTION,
     NORMALIZATION,
     DROPOUT,
     FREEWEIGHT,
@@ -48,7 +49,7 @@ char *get_layer_string(LAYER_TYPE a);
 
 network make_network(int n, int batch);
 void forward_network(network net, float *input, float *truth, int train);
-void backward_network(network net, float *input);
+void backward_network(network net, float *input, float *truth);
 void update_network(network net);
 
 float train_network(network net, data d);
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index b83d0566fa0c44cdb6869d2da27dfb254d27decc..928c7f95677fd0463607854e48aa147e8be601cd 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -9,6 +9,7 @@ extern "C" {
 
 #include "crop_layer.h"
 #include "connected_layer.h"
+#include "detection_layer.h"
 #include "convolutional_layer.h"
 #include "deconvolutional_layer.h"
 #include "maxpool_layer.h"
@@ -47,6 +48,11 @@ void forward_network_gpu(network net, float * input, float * truth, int train)
             forward_connected_layer_gpu(layer, input);
             input = layer.output_gpu;
         }
+        else if(net.types[i] == DETECTION){
+            detection_layer layer = *(detection_layer *)net.layers[i];
+            forward_detection_layer_gpu(layer, input, truth);
+            input = layer.output_gpu;
+        }
         else if(net.types[i] == MAXPOOL){
             maxpool_layer layer = *(maxpool_layer *)net.layers[i];
             forward_maxpool_layer_gpu(layer, input);
@@ -73,7 +79,7 @@ void forward_network_gpu(network net, float * input, float * truth, int train)
     }
 }
 
-void backward_network_gpu(network net, float * input)
+void backward_network_gpu(network net, float * input, float *truth)
 {
     int i;
     float * prev_input;
@@ -103,6 +109,10 @@ void backward_network_gpu(network net, float * input)
             connected_layer layer = *(connected_layer *)net.layers[i];
             backward_connected_layer_gpu(layer, prev_input, prev_delta);
         }
+        else if(net.types[i] == DETECTION){
+            detection_layer layer = *(detection_layer *)net.layers[i];
+            backward_detection_layer_gpu(layer, prev_input, prev_delta);
+        }
         else if(net.types[i] == MAXPOOL){
             maxpool_layer layer = *(maxpool_layer *)net.layers[i];
             backward_maxpool_layer_gpu(layer, prev_delta);
@@ -148,6 +158,10 @@ float * get_network_output_gpu_layer(network net, int i)
         deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
         return layer.output_gpu;
     }
+    else if(net.types[i] == DETECTION){
+        detection_layer layer = *(detection_layer *)net.layers[i];
+        return layer.output_gpu;
+    }
     else if(net.types[i] == CONNECTED){
         connected_layer layer = *(connected_layer *)net.layers[i];
         return layer.output_gpu;
@@ -176,6 +190,10 @@ float * get_network_delta_gpu_layer(network net, int i)
         convolutional_layer layer = *(convolutional_layer *)net.layers[i];
         return layer.delta_gpu;
     }
+    else if(net.types[i] == DETECTION){
+        detection_layer layer = *(detection_layer *)net.layers[i];
+        return layer.delta_gpu;
+    }
     else if(net.types[i] == DECONVOLUTIONAL){
         deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
         return layer.delta_gpu;
@@ -215,7 +233,7 @@ float train_network_datum_gpu(network net, float *x, float *y)
     forward_network_gpu(net, *net.input_gpu, *net.truth_gpu, 1);
   //printf("forw %f\n", sec(clock() - time));
   //time = clock();
-    backward_network_gpu(net, *net.input_gpu);
+    backward_network_gpu(net, *net.input_gpu, *net.truth_gpu);
   //printf("back %f\n", sec(clock() - time));
   //time = clock();
     update_network_gpu(net);
@@ -244,6 +262,12 @@ float *get_network_output_layer_gpu(network net, int i)
         cuda_pull_array(layer.output_gpu, layer.output, layer.outputs*layer.batch);
         return layer.output;
     }
+    else if(net.types[i] == DETECTION){
+        detection_layer layer = *(detection_layer *)net.layers[i];
+        int outputs = get_detection_layer_output_size(layer);
+        cuda_pull_array(layer.output_gpu, layer.output, outputs*layer.batch);
+        return layer.output;
+    }
     else if(net.types[i] == MAXPOOL){
         maxpool_layer layer = *(maxpool_layer *)net.layers[i];
         return layer.output;
diff --git a/src/option_list.c b/src/option_list.c
index 76e10166369d5ab8c5c273c79a76bd7d8423efef..f5536e11cff2fef59dc66c6b0dc873c4b3ab4df8 100644
--- a/src/option_list.c
+++ b/src/option_list.c
@@ -53,6 +53,13 @@ int option_find_int(list *l, char *key, int def)
     return def;
 }
 
+int option_find_int_quiet(list *l, char *key, int def)
+{
+    char *v = option_find(l, key);
+    if(v) return atoi(v);
+    return def;
+}
+
 float option_find_float_quiet(list *l, char *key, float def)
 {
     char *v = option_find(l, key);
diff --git a/src/option_list.h b/src/option_list.h
index fa795f3e2ea400af091b70c134716dfa24f22aab..4441462821550742834659da8d3c40c6859d4948 100644
--- a/src/option_list.h
+++ b/src/option_list.h
@@ -13,6 +13,7 @@ void option_insert(list *l, char *key, char *val);
 char *option_find(list *l, char *key);
 char *option_find_str(list *l, char *key, char *def);
 int option_find_int(list *l, char *key, int def);
+int option_find_int_quiet(list *l, char *key, int def);
 float option_find_float(list *l, char *key, float def);
 float option_find_float_quiet(list *l, char *key, float def);
 void option_unused(list *l);
diff --git a/src/parser.c b/src/parser.c
index 850cc38832566d72f54ba9448a260be85b0133ff..53e1f569f99a0c3af13626860388174221e7b0bf 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -13,6 +13,7 @@
 #include "normalization_layer.h"
 #include "softmax_layer.h"
 #include "dropout_layer.h"
+#include "detection_layer.h"
 #include "freeweight_layer.h"
 #include "list.h"
 #include "option_list.h"
@@ -32,6 +33,7 @@ int is_freeweight(section *s);
 int is_softmax(section *s);
 int is_crop(section *s);
 int is_cost(section *s);
+int is_detection(section *s);
 int is_normalization(section *s);
 list *read_cfg(char *filename);
 
@@ -204,6 +206,24 @@ softmax_layer *parse_softmax(list *options, network *net, int count)
     return layer;
 }
 
+detection_layer *parse_detection(list *options, network *net, int count)
+{
+    int input;
+    if(count == 0){
+        input = option_find_int(options, "input",1);
+        net->batch = option_find_int(options, "batch",1);
+        net->seen = option_find_int(options, "seen",0);
+    }else{
+        input =  get_network_output_size_layer(*net, count-1);
+    }
+    int coords = option_find_int(options, "coords", 1);
+    int classes = option_find_int(options, "classes", 1);
+    int rescore = option_find_int(options, "rescore", 1);
+    detection_layer *layer = make_detection_layer(net->batch, input, classes, coords, rescore);
+    option_unused(options);
+    return layer;
+}
+
 cost_layer *parse_cost(list *options, network *net, int count)
 {
     int input;
@@ -368,6 +388,10 @@ network parse_network_cfg(char *filename)
             cost_layer *layer = parse_cost(options, &net, count);
             net.types[count] = COST;
             net.layers[count] = layer;
+        }else if(is_detection(s)){
+            detection_layer *layer = parse_detection(options, &net, count);
+            net.types[count] = DETECTION;
+            net.layers[count] = layer;
         }else if(is_softmax(s)){
             softmax_layer *layer = parse_softmax(options, &net, count);
             net.types[count] = SOFTMAX;
@@ -410,6 +434,10 @@ int is_cost(section *s)
 {
     return (strcmp(s->type, "[cost]")==0);
 }
+int is_detection(section *s)
+{
+    return (strcmp(s->type, "[detection]")==0);
+}
 int is_deconvolutional(section *s)
 {
     return (strcmp(s->type, "[deconv]")==0
@@ -684,6 +712,13 @@ void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count)
     fprintf(fp, "\n");
 }
 
+void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count)
+{
+    fprintf(fp, "[detection]\n");
+    fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\n", l->classes, l->coords, l->rescore);
+    fprintf(fp, "\n");
+}
+
 void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count)
 {
     fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type));
@@ -815,6 +850,8 @@ void save_network(network net, char *filename)
             print_normalization_cfg(fp, (normalization_layer *)net.layers[i], net, i);
         else if(net.types[i] == SOFTMAX)
             print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i);
+        else if(net.types[i] == DETECTION)
+            print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i);
         else if(net.types[i] == COST)
             print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i);
     }
diff --git a/src/softmax_layer.h b/src/softmax_layer.h
index 1c1cdae839a652734e9a88f5d83f9600abe6dab4..3632c7471938b3e8012641e8ead512e396a67370 100644
--- a/src/softmax_layer.h
+++ b/src/softmax_layer.h
@@ -13,6 +13,7 @@ typedef struct {
     #endif
 } softmax_layer;
 
+void softmax_array(float *input, int n, float *output);
 softmax_layer *make_softmax_layer(int batch, int groups, int inputs);
 void forward_softmax_layer(const softmax_layer layer, float *input);
 void backward_softmax_layer(const softmax_layer layer, float *delta);