diff --git a/Makefile b/Makefile
index 3dc564da0b934cec90a3a4d2d185672acc7f0ff4..eee3c96c5cd46be0c411bd9964f1ea54b0ecfcfe 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,7 @@ OBJDIR=./obj/
 
 CC=gcc
 NVCC=nvcc
-OPTS=-O3
+OPTS=-O0
 LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++
 COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/
 CFLAGS=-Wall -Wfatal-errors
@@ -22,7 +22,7 @@ CFLAGS+=$(OPTS)
 ifeq ($(GPU), 1) 
 COMMON+=-DGPU
 CFLAGS+=-DGPU
-LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas
+LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
 endif
 
 OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o
diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index 5ee1524478a632cc875b819299021ecf1c7b3142..32c032ce0dcae4e505d5644d9b6c5232854d9ac4 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -8,12 +8,19 @@ __device__ float logistic_activate_kernel(float x){return 1./(1. + exp(-x));}
 __device__ float relu_activate_kernel(float x){return x*(x>0);}
 __device__ float ramp_activate_kernel(float x){return x*(x>0)+.1*x;}
 __device__ float tanh_activate_kernel(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
+__device__ float plse_activate_kernel(float x)
+{
+    if(x < -4) return .01 * (x + 4);
+    if(x > 4)  return .01 * (x - 4) + 1;
+    return .125*x + .5;
+}
  
 __device__ float linear_gradient_kernel(float x){return 1;}
 __device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
 __device__ float relu_gradient_kernel(float x){return (x>0);}
 __device__ float ramp_gradient_kernel(float x){return (x>0)+.1;}
 __device__ float tanh_gradient_kernel(float x){return 1-x*x;}
+__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01 : .125;}
 
 __device__ float activate_kernel(float x, ACTIVATION a)
 {
@@ -28,6 +35,8 @@ __device__ float activate_kernel(float x, ACTIVATION a)
             return ramp_activate_kernel(x);
         case TANH:
             return tanh_activate_kernel(x);
+        case PLSE:
+            return plse_activate_kernel(x);
     }
     return 0;
 }
@@ -45,6 +54,8 @@ __device__ float gradient_kernel(float x, ACTIVATION a)
             return ramp_gradient_kernel(x);
         case TANH:
             return tanh_gradient_kernel(x);
+        case PLSE:
+            return plse_gradient_kernel(x);
     }
     return 0;
 }
diff --git a/src/activations.c b/src/activations.c
index 7da5ce25927b4d95ef0ce1d7773e562a081b0590..20fc97b47fabff7f168ea344270c2b4c22a1510e 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -18,6 +18,8 @@ char *get_activation_string(ACTIVATION a)
             return "linear";
         case TANH:
             return "tanh";
+        case PLSE:
+            return "plse";
         default:
             break;
     }
@@ -28,6 +30,7 @@ ACTIVATION get_activation(char *s)
 {
     if (strcmp(s, "logistic")==0) return LOGISTIC;
     if (strcmp(s, "relu")==0) return RELU;
+    if (strcmp(s, "plse")==0) return PLSE;
     if (strcmp(s, "linear")==0) return LINEAR;
     if (strcmp(s, "ramp")==0) return RAMP;
     if (strcmp(s, "tanh")==0) return TANH;
@@ -48,6 +51,8 @@ float activate(float x, ACTIVATION a)
             return ramp_activate(x);
         case TANH:
             return tanh_activate(x);
+        case PLSE:
+            return plse_activate(x);
     }
     return 0;
 }
@@ -73,6 +78,8 @@ float gradient(float x, ACTIVATION a)
             return ramp_gradient(x);
         case TANH:
             return tanh_gradient(x);
+        case PLSE:
+            return plse_gradient(x);
     }
     return 0;
 }
diff --git a/src/activations.h b/src/activations.h
index 0cb81af67efc91dde7e8f408a4e1452fa86baabd..f28ac0d4e8cad6b3378082b6ae68019f61cc29d5 100644
--- a/src/activations.h
+++ b/src/activations.h
@@ -3,7 +3,7 @@
 #define ACTIVATIONS_H
 
 typedef enum{
-    LOGISTIC, RELU, LINEAR, RAMP, TANH
+    LOGISTIC, RELU, LINEAR, RAMP, TANH, PLSE
 }ACTIVATION;
 
 ACTIVATION get_activation(char *s);
@@ -23,12 +23,19 @@ static inline float logistic_activate(float x){return 1./(1. + exp(-x));}
 static inline float relu_activate(float x){return x*(x>0);}
 static inline float ramp_activate(float x){return x*(x>0)+.1*x;}
 static inline float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
+static inline float plse_activate(float x)
+{
+    if(x < -4) return .01 * (x + 4);
+    if(x > 4)  return .01 * (x - 4) + 1;
+    return .125*x + .5;
+}
 
 static inline float linear_gradient(float x){return 1;}
 static inline float logistic_gradient(float x){return (1-x)*x;}
 static inline float relu_gradient(float x){return (x>0);}
 static inline float ramp_gradient(float x){return (x>0)+.1;}
 static inline float tanh_gradient(float x){return 1-x*x;}
+static inline float plse_gradient(float x){return (x < 0 || x > 1) ? .01 : .125;}
 
 #endif
 
diff --git a/src/captcha.c b/src/captcha.c
index 40a4082dc98d36f05e835a595a3fdc8ef06c34be..6e02f5aca92eda3607cd7c616b2aabab677d1291 100644
--- a/src/captcha.c
+++ b/src/captcha.c
@@ -16,7 +16,7 @@ void train_captcha(char *cfgfile, char *weightfile)
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     int imgs = 1024;
     int i = net.seen/imgs;
-    list *plist = get_paths("/data/captcha/train.base");
+    list *plist = get_paths("/data/captcha/train.auto5");
     char **paths = (char **)list_to_array(plist);
     printf("%d\n", plist->size);
     clock_t time;
@@ -34,7 +34,7 @@ void train_captcha(char *cfgfile, char *weightfile)
         avg_loss = avg_loss*.9 + loss*.1;
         printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
         free_data(train);
-        if(i%100==0){
+        if(i%10==0){
             char buff[256];
             sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
             save_weights(net, buff);
@@ -56,11 +56,11 @@ void decode_captcha(char *cfgfile, char *weightfile)
         printf("Enter filename: ");
         fgets(filename, 256, stdin);
         strtok(filename, "\n");
-        image im = load_image_color(filename, 60, 200);
+        image im = load_image_color(filename, 57, 300);
         scale_image(im, 1./255.);
         float *X = im.data;
         float *predictions = network_predict(net, X);
-        image out  = float_to_image(60, 200, 3, predictions);
+        image out  = float_to_image(57, 300, 1, predictions);
         show_image(out, "decoded");
         cvWaitKey(0);
         free_image(im);
@@ -87,7 +87,7 @@ void encode_captcha(char *cfgfile, char *weightfile)
     while(1){
         ++i;
         time=clock();
-        data train = load_data_captcha_encode(paths, imgs, plist->size, 60, 200);
+        data train = load_data_captcha_encode(paths, imgs, plist->size, 57, 300);
         scale_data_rows(train, 1./255);
         printf("Loaded: %lf seconds\n", sec(clock()-time));
         time=clock();
@@ -114,10 +114,10 @@ void validate_captcha(char *cfgfile, char *weightfile)
     if(weightfile){
         load_weights(&net, weightfile);
     }
-    int imgs = 1000;
     int numchars = 37;
-    list *plist = get_paths("/data/captcha/valid.base");
+    list *plist = get_paths("/data/captcha/solved.hard");
     char **paths = (char **)list_to_array(plist);
+    int imgs = plist->size;
     data valid = load_data_captcha(paths, imgs, 0, 10, 60, 200);
     translate_data_rows(valid, -128);
     scale_data_rows(valid, 1./128);
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 77304aa87fe6dc7a4a46e3da17091b4152e9b768..864d7fa3bacd9bb505c1ab8a1a0f5250450dec3f 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -56,6 +56,7 @@ extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch,
 
 extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
 {
+clock_t time = clock();
     int i;
     int m = layer.n;
     int k = layer.size*layer.size*layer.c;
@@ -63,15 +64,31 @@ extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, netwo
         convolutional_out_width(layer);
 
     bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
+cudaDeviceSynchronize();
+printf("bias %f\n", sec(clock() - time));
+time = clock();
 
+float imt=0;
+float gemt = 0;
     for(i = 0; i < layer.batch; ++i){
+time = clock();
         im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
+cudaDeviceSynchronize();
+imt += sec(clock()-time);
+time = clock();
         float * a = layer.filters_gpu;
         float * b = layer.col_image_gpu;
         float * c = layer.output_gpu;
         gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
+cudaDeviceSynchronize();
+gemt += sec(clock()-time);
+time = clock();
     }
     activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation);
+cudaDeviceSynchronize();
+printf("activate %f\n", sec(clock() - time));
+printf("im2col %f\n", imt);
+printf("gemm %f\n", gemt);
 }
 
 extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
diff --git a/src/cuda.c b/src/cuda.c
index c9142905af50ea4e94c748dae96b5affebb76f74..79829534309bb2985562c4c8d78fcbda888d4520 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -59,6 +59,18 @@ float *cuda_make_array(float *x, int n)
     return x_gpu;
 }
 
+void cuda_random(float *x_gpu, int n)
+{
+    static curandGenerator_t gen;
+    static int init = 0;
+    if(!init){
+        curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
+        curandSetPseudoRandomGeneratorSeed(gen, 0ULL);
+    }
+    curandGenerateUniform(gen, x_gpu, n);
+    check_error(cudaPeekAtLastError());
+}
+
 float cuda_compare(float *x_gpu, float *x, int n, char *s)
 {
     float *tmp = calloc(n, sizeof(float));
diff --git a/src/cuda.h b/src/cuda.h
index cbe79755d0956def2ad525444e200ff97eb88ad1..ff0d5f337bfb5cce4c1dafb9f2fc55720d012896 100644
--- a/src/cuda.h
+++ b/src/cuda.h
@@ -8,6 +8,7 @@ extern int gpu_index;
 #define BLOCK 256
 
 #include "cuda_runtime.h"
+#include "curand.h"
 #include "cublas_v2.h"
 
 void check_error(cudaError_t status);
@@ -17,6 +18,7 @@ int *cuda_make_int_array(int n);
 void cuda_push_array(float *x_gpu, float *x, int n);
 void cuda_pull_array(float *x_gpu, float *x, int n);
 void cuda_free(float *x_gpu);
+void cuda_random(float *x_gpu, int n);
 float cuda_compare(float *x_gpu, float *x, int n, char *s);
 dim3 cuda_gridsize(size_t n);
 
diff --git a/src/data.c b/src/data.c
index 342edfa4f31bce99f1bee1589411945ecb8454ef..8dd7d9a6309f51c7fcb36ad4af177273dc6083db 100644
--- a/src/data.c
+++ b/src/data.c
@@ -112,7 +112,12 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int
     randomize_boxes(boxes, count);
     float x, y, h, w;
     int id;
-    int i, j;
+    int i;
+    if(background){
+        for(i = 0; i < num_height*num_width*(4+classes+background); i += 4+classes+background){
+            truth[i] = 1;
+        }
+    }
     for(i = 0; i < count; ++i){
         x = boxes[i].x;
         y = boxes[i].y;
@@ -137,21 +142,15 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int
 
         int index = (i+j*num_width)*(4+classes+background);
         if(truth[index+classes+background]) continue;
+        if(background) truth[index++] = 0;
         truth[index+id] = 1;
-        index += classes+background;
+        index += classes;
         truth[index++] = dh;
         truth[index++] = dw;
         truth[index++] = h*(height+jitter)/height;
         truth[index++] = w*(width+jitter)/width;
     }
     free(boxes);
-    if(background){
-        for(i = 0; i < num_height*num_width*(4+classes+background); i += 4+classes+background){
-            int object = 0;
-            for(j = i; j < i+classes; ++j) if (truth[j]) object = 1;
-            truth[i+classes] = !object;
-        }
-    }
 }
 
 #define NUMCHARS 37
@@ -202,6 +201,7 @@ data load_data_captcha_encode(char **paths, int n, int m, int h, int w)
     data d;
     d.shallow = 0;
     d.X = load_image_paths(paths, n, h, w);
+    d.X.cols = 17100;
     d.y = d.X;
     if(m) free(paths);
     return d;
diff --git a/src/detection.c b/src/detection.c
index f86134780d09b3b3fa99738d72ab5ab3532a5d2d..15694c51c3483e1377771de2835abe3effb59227 100644
--- a/src/detection.c
+++ b/src/detection.c
@@ -108,7 +108,7 @@ void validate_detection(char *cfgfile, char *weightfile)
     char **paths = (char **)list_to_array(plist);
     int im_size = 448;
     int classes = 20;
-    int background = 0;
+    int background = 1;
     int num_output = 7*7*(4+classes+background);
 
     int m = plist->size;
@@ -143,7 +143,7 @@ void validate_detection(char *cfgfile, char *weightfile)
                     float x = (c + pred.vals[j][ci + 1])/7.;
                     float h = pred.vals[j][ci + 2];
                     float w = pred.vals[j][ci + 3];
-                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class], y, x, h, w);
+                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class+background], y, x, h, w);
                 }
             }
         }
diff --git a/src/detection_layer.c b/src/detection_layer.c
index 5ca7fa2b5c0b486cc2e3bc824d54633a49cdc6c1..0a754fd4b3bef077606efd184d4ad56143ad832b 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -8,23 +8,24 @@
 
 int get_detection_layer_locations(detection_layer layer)
 {
-    return layer.inputs / (layer.classes+layer.coords+layer.rescore);
+    return layer.inputs / (layer.classes+layer.coords+layer.rescore+layer.background);
 }
 
 int get_detection_layer_output_size(detection_layer layer)
 {
-    return get_detection_layer_locations(layer)*(layer.classes+layer.coords);
+    return get_detection_layer_locations(layer)*(layer.background + layer.classes + layer.coords);
 }
 
-detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore)
+detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background)
 {
     detection_layer *layer = calloc(1, sizeof(detection_layer));
-
+    
     layer->batch = batch;
     layer->inputs = inputs;
     layer->classes = classes;
     layer->coords = coords;
     layer->rescore = rescore;
+    layer->background = background;
     int outputs = get_detection_layer_output_size(*layer);
     layer->output = calloc(batch*outputs, sizeof(float));
     layer->delta = calloc(batch*outputs, sizeof(float));
@@ -39,6 +40,27 @@ detection_layer *make_detection_layer(int batch, int inputs, int classes, int co
     return layer;
 }
 
+void dark_zone(detection_layer layer, int class, int start, network_state state)
+{
+    int index = start+layer.background+class;
+    int size = layer.classes+layer.coords+layer.background;
+    int location = (index%(7*7*size)) / size ;
+    int r = location / 7;
+    int c = location % 7;
+    int dr, dc;
+    for(dr = -1; dr <= 1; ++dr){
+        for(dc = -1; dc <= 1; ++dc){
+            if(!(dr || dc)) continue;
+            if((r + dr) > 6 || (r + dr) < 0) continue;
+            if((c + dc) > 6 || (c + dc) < 0) continue;
+            int di = (dr*7 + dc) * size;
+            if(state.truth[index+di]) continue;
+            layer.output[index + di] = 0;
+            //if(!state.truth[start+di]) continue;
+            //layer.output[start + di] = 1;
+        }
+    }
+}
 
 void forward_detection_layer(const detection_layer layer, network_state state)
 {
@@ -47,39 +69,30 @@ void forward_detection_layer(const detection_layer layer, network_state state)
     int locations = get_detection_layer_locations(layer);
     int i,j;
     for(i = 0; i < layer.batch*locations; ++i){
-        int mask = (!state.truth || state.truth[out_i + layer.classes + 2]);
+        int mask = (!state.truth || state.truth[out_i + layer.background + layer.classes + 2]);
         float scale = 1;
         if(layer.rescore) scale = state.input[in_i++];
+        if(layer.background) layer.output[out_i++] = scale*state.input[in_i++];
+
         for(j = 0; j < layer.classes; ++j){
             layer.output[out_i++] = scale*state.input[in_i++];
         }
-        if(!layer.rescore){
-            softmax_array(layer.output + out_i - layer.classes, layer.classes, layer.output + out_i - layer.classes);
+        if(layer.background){
+            softmax_array(layer.output + out_i - layer.classes-layer.background, layer.classes+layer.background, layer.output + out_i - layer.classes-layer.background);
             activate_array(state.input+in_i, layer.coords, LOGISTIC);
         }
         for(j = 0; j < layer.coords; ++j){
             layer.output[out_i++] = mask*state.input[in_i++];
         }
     }
-}
-
-void dark_zone(detection_layer layer, int index, network_state state)
-{
-    int size = layer.classes+layer.rescore+layer.coords;
-    int location = (index%(7*7*size)) / size ;
-    int r = location / 7;
-    int c = location % 7;
-    int class = index%size;
-    if(layer.rescore) --class;
-    int dr, dc;
-    for(dr = -1; dr <= 1; ++dr){
-        for(dc = -1; dc <= 1; ++dc){
-            if(!(dr || dc)) continue;
-            if((r + dr) > 6 || (r + dr) < 0) continue;
-            if((c + dc) > 6 || (c + dc) < 0) continue;
-            int di = (dr*7 + dc) * size;
-            if(state.truth[index+di]) continue;
-            layer.delta[index + di] = 0;
+    if(layer.background || 1){
+        for(i = 0; i < layer.batch*locations; ++i){
+            int index = i*(layer.classes+layer.coords+layer.background);
+            for(j= 0; j < layer.classes; ++j){
+                if(state.truth[index+j+layer.background]){
+                    //dark_zone(layer, j, index, state);
+                }
+            }
         }
     }
 }
@@ -94,21 +107,17 @@ void backward_detection_layer(const detection_layer layer, network_state state)
         float scale = 1;
         float latent_delta = 0;
         if(layer.rescore) scale = state.input[in_i++];
-        if(!layer.rescore){
-            for(j = 0; j < layer.classes-1; ++j){
-                if(state.truth[out_i + j]) dark_zone(layer, out_i+j, state);
-            }
-        }
+        if(layer.background) state.delta[in_i++] = scale*layer.delta[out_i++];
         for(j = 0; j < layer.classes; ++j){
             latent_delta += state.input[in_i]*layer.delta[out_i];
             state.delta[in_i++] = scale*layer.delta[out_i++];
         }
 
-        if (!layer.rescore) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i);
+        if (layer.background) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i);
         for(j = 0; j < layer.coords; ++j){
             state.delta[in_i++] = layer.delta[out_i++];
         }
-        if(layer.rescore) state.delta[in_i-layer.coords-layer.classes-layer.rescore] = latent_delta;
+        if(layer.rescore) state.delta[in_i-layer.coords-layer.classes-layer.rescore-layer.background] = latent_delta;
     }
 }
 
diff --git a/src/detection_layer.h b/src/detection_layer.h
index 69a83a73c1f55758f529b581e7593506983b1ba8..2ad1ef2350528485586f327839c3da5483a3b2b4 100644
--- a/src/detection_layer.h
+++ b/src/detection_layer.h
@@ -8,6 +8,7 @@ typedef struct {
     int inputs;
     int classes;
     int coords;
+    int background;
     int rescore;
     float *output;
     float *delta;
@@ -17,7 +18,7 @@ typedef struct {
     #endif
 } detection_layer;
 
-detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore);
+detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background);
 void forward_detection_layer(const detection_layer layer, network_state state);
 void backward_detection_layer(const detection_layer layer, network_state state);
 int get_detection_layer_output_size(detection_layer layer);
diff --git a/src/dropout_layer_kernels.cu b/src/dropout_layer_kernels.cu
index 94f61ab0fb961ccc2b7f717b1168ac040495e878..4561d89df7d774d29a24d3aaf594e53a26679e9d 100644
--- a/src/dropout_layer_kernels.cu
+++ b/src/dropout_layer_kernels.cu
@@ -14,10 +14,8 @@ __global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand
 extern "C" void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
 {
     if (!state.train) return;
-    int j;
     int size = layer.inputs*layer.batch;
-    for(j = 0; j < size; ++j) layer.rand[j] = rand_uniform();
-    cuda_push_array(layer.rand_gpu, layer.rand, layer.inputs*layer.batch);
+    cuda_random(layer.rand_gpu, size);
 
     yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
     check_error(cudaPeekAtLastError());
diff --git a/src/imagenet.c b/src/imagenet.c
index 9118c084e3239bb714e0df3f35a98a3004957785..7da73a09d45d00495bad923406b6803c51a37891 100644
--- a/src/imagenet.c
+++ b/src/imagenet.c
@@ -13,7 +13,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
         load_weights(&net, weightfile);
     }
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
-    int imgs = 1024;
+    int imgs = 128;
     int i = net.seen/imgs;
     char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
     list *plist = get_paths("/data/imagenet/cls.train.list");
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index acc31d7cd5699b835d901a5d340c61c7a5860d88..03cb149fae9afddf5dda3d2aa94f330898584493 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -28,6 +28,7 @@ void forward_network_gpu(network net, network_state state)
 {
     int i;
     for(i = 0; i < net.n; ++i){
+//clock_t time = clock();
         if(net.types[i] == CONVOLUTIONAL){
             forward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state);
         }
@@ -56,6 +57,9 @@ void forward_network_gpu(network net, network_state state)
             forward_crop_layer_gpu(*(crop_layer *)net.layers[i], state);
         }
         state.input = get_network_output_gpu_layer(net, i);
+//cudaDeviceSynchronize();
+//printf("forw %d: %s %f\n", i, get_layer_string(net.types[i]), sec(clock() - time));
+//time = clock();
     }
 }
 
@@ -64,7 +68,7 @@ void backward_network_gpu(network net, network_state state)
     int i;
     float * original_input = state.input;
     for(i = net.n-1; i >= 0; --i){
-        //clock_t time = clock();
+//clock_t time = clock();
         if(i == 0){
             state.input = original_input;
             state.delta = 0;
@@ -96,6 +100,9 @@ void backward_network_gpu(network net, network_state state)
         else if(net.types[i] == SOFTMAX){
             backward_softmax_layer_gpu(*(softmax_layer *)net.layers[i], state);
         }
+//cudaDeviceSynchronize();
+//printf("back %d: %s %f\n", i, get_layer_string(net.types[i]), sec(clock() - time));
+//time = clock();
     }
 }
 
@@ -181,7 +188,7 @@ float * get_network_delta_gpu_layer(network net, int i)
 
 float train_network_datum_gpu(network net, float *x, float *y)
 {
-  //clock_t time = clock();
+ // clock_t time = clock();
     network_state state;
     int x_size = get_network_input_size(net)*net.batch;
     int y_size = get_network_output_size(net)*net.batch;
@@ -195,22 +202,26 @@ float train_network_datum_gpu(network net, float *x, float *y)
     state.input = *net.input_gpu;
     state.truth = *net.truth_gpu;
     state.train = 1;
-  //printf("trans %f\n", sec(clock() - time));
-  //time = clock();
+//cudaDeviceSynchronize();
+//printf("trans %f\n", sec(clock() - time));
+//time = clock();
     forward_network_gpu(net, state);
-  //printf("forw %f\n", sec(clock() - time));
-  //time = clock();
+//cudaDeviceSynchronize();
+//printf("forw %f\n", sec(clock() - time));
+//time = clock();
     backward_network_gpu(net, state);
-  //printf("back %f\n", sec(clock() - time));
-  //time = clock();
+//cudaDeviceSynchronize();
+//printf("back %f\n", sec(clock() - time));
+//time = clock();
     update_network_gpu(net);
     float error = get_network_cost(net);
 
     //print_letters(y, 50);
     //float *out = get_network_output_gpu(net);
     //print_letters(out, 50);
-  //printf("updt %f\n", sec(clock() - time));
-  //time = clock();
+//cudaDeviceSynchronize();
+//printf("updt %f\n", sec(clock() - time));
+//time = clock();
     return error;
 }
 
@@ -256,7 +267,6 @@ float *get_network_output_gpu(network net)
 
 float *network_predict_gpu(network net, float *input)
 {
-
     int size = get_network_input_size(net) * net.batch;
     network_state state;
     state.input = cuda_make_array(input, size);
diff --git a/src/parser.c b/src/parser.c
index d7c4a3104d98f6c909f422cb9cff2499f81e6e12..81d1f8fce336c145076938ee1099492ad82a9cd3 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -165,7 +165,8 @@ detection_layer *parse_detection(list *options, size_params params)
     int coords = option_find_int(options, "coords", 1);
     int classes = option_find_int(options, "classes", 1);
     int rescore = option_find_int(options, "rescore", 1);
-    detection_layer *layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore);
+    int background = option_find_int(options, "background", 1);
+    detection_layer *layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore, background);
     option_unused(options);
     return layer;
 }