diff --git a/Makefile b/Makefile
index 32ff4c0a63b7b9b8aa83f01f1e1064f6a1f9c64e..1b4227775c7b1056ea09318ac3c68b8ed8ca390d 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
-GPU=0
-OPENCV=0
+GPU=1
+OPENCV=1
 DEBUG=0
 
 ARCH= -arch=sm_52
diff --git a/src/avgpool_layer.c b/src/avgpool_layer.c
index ee56161ef916d2073909735ecea97e9ea178f2f7..8eccde624898feec7660efe60f646de597d03496 100644
--- a/src/avgpool_layer.c
+++ b/src/avgpool_layer.c
@@ -58,7 +58,7 @@ void backward_avgpool_layer(const avgpool_layer l, network_state state)
             int out_index = k + b*l.c;
             for(i = 0; i < l.h*l.w; ++i){
                 int in_index = i + l.h*l.w*(k + b*l.c);
-                state.delta[in_index] = l.delta[out_index] / (l.h*l.w);
+                state.delta[in_index] += l.delta[out_index] / (l.h*l.w);
             }
         }
     }
diff --git a/src/avgpool_layer_kernels.cu b/src/avgpool_layer_kernels.cu
index ca628c68e8c35d9f45bff4c7e3e7c32b41260019..1bd2a2af2d0c43733f281622c9fe8a4d55c8d254 100644
--- a/src/avgpool_layer_kernels.cu
+++ b/src/avgpool_layer_kernels.cu
@@ -35,7 +35,7 @@ __global__ void backward_avgpool_layer_kernel(int n, int w, int h, int c, float
     int out_index = (k + c*b);
     for(i = 0; i < w*h; ++i){
         int in_index = i + h*w*(k + b*c);
-        in_delta[in_index] = out_delta[out_index] / (w*h);
+        in_delta[in_index] += out_delta[out_index] / (w*h);
     }
 }
 
diff --git a/src/col2im_kernels.cu b/src/col2im_kernels.cu
index 67c0b03440015cae350bf60a543ab1d0371435e5..7262f923124c700e46ee07f82c02596360e3a92e 100644
--- a/src/col2im_kernels.cu
+++ b/src/col2im_kernels.cu
@@ -33,7 +33,7 @@ __global__ void col2im_gpu_kernel(const int n, const float* data_col,
                 val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
             }
         }
-        data_im[index] = val;
+        data_im[index] += val;
     }
 }
 
@@ -53,62 +53,3 @@ void col2im_ongpu(float *data_col,
                 width_col, data_im);
 }
 
-/*
-   __global__ void col2im_kernel(float *data_col,
-   int channels, int height, int width,
-   int ksize, int stride, int pad, float *data_im)
-   {
-
-   int height_col = (height - ksize) / stride + 1;
-   int width_col = (width - ksize) / stride + 1;
-   if (pad){
-   height_col = 1 + (height-1) / stride;
-   width_col = 1 + (width-1) / stride;
-   pad = ksize/2;
-   }
-
-   int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-   if(id >= channels*height*width) return;
-
-   int index = id;
-   int w = id%width + pad;
-   id /= width;
-   int h = id%height + pad;
-   id /= height;
-   int c = id%channels;
-
-   int w_start = (w-ksize+stride)/stride;
-   int w_end = w/stride + 1;
-
-   int h_start = (h-ksize+stride)/stride;
-   int h_end = h/stride + 1;
-
-// int rows = channels * ksize * ksize;
-// int cols = height_col*width_col;
-int col_offset = (c*ksize*ksize + h * ksize + w)*height_col*width_col;
-int h_coeff = (1-stride*ksize*height_col)*width_col;
-int w_coeff = 1-stride*height_col*width_col;
-float val = 0;
-int h_col, w_col;
-for(h_col = h_start; h_col < h_end; ++h_col){
-for(w_col = w_start; w_col < w_end; ++w_col){
-int col_index = col_offset +h_col*h_coeff + w_col*w_coeff;
-float part = (w_col < 0 || h_col < 0 || h_col >= height_col || w_col >= width_col) ? 0 : data_col[col_index];
-val += part;
-}
-}
-data_im[index] = val;
-}
-
-
-extern "C" void col2im_ongpu(float *data_col,
-int channels,  int height,  int width,
-int ksize,  int stride,  int pad, float *data_im)
-{
-
-size_t n = channels*height*width;
-
-col2im_kernel<<<cuda_gridsize(n), BLOCK>>>(data_col, channels, height, width, ksize, stride, pad, data_im);
-check_error(cudaPeekAtLastError());
-}
- */
diff --git a/src/connected_layer.c b/src/connected_layer.c
index 55d84cac3392850d319be787444ca9874e9dd64d..432350575985ffaa89fc3b9efa2483466f0e4438 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -103,7 +103,7 @@ void backward_connected_layer(connected_layer l, network_state state)
     b = l.weights;
     c = state.delta;
 
-    if(c) gemm(0,1,m,n,k,1,a,k,b,k,0,c,n);
+    if(c) gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
 }
 
 #ifdef GPU
@@ -173,6 +173,6 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
     b = l.weights_gpu;
     c = state.delta;
 
-    if(c) gemm_ongpu(0,1,m,n,k,1,a,k,b,k,0,c,n);
+    if(c) gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
 }
 #endif
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index d260a95b8bbd38061abef4eba4565823df9ecec7..a150c2050a80666fc95f69d8634400e0bf430254 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -82,8 +82,6 @@ void backward_convolutional_layer_gpu(convolutional_layer layer, network_state s
     gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu);
     backward_bias_gpu(layer.bias_updates_gpu, layer.delta_gpu, layer.batch, layer.n, k);
 
-    if(state.delta) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, state.delta, 1);
-
     for(i = 0; i < layer.batch; ++i){
         float * a = layer.delta_gpu;
         float * b = layer.col_image_gpu;
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index c2669348e4265bd1936dfb6e0cb3a90690b64795..c3a3718cfa908dd67908d83965c42fd6500e7140 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -188,8 +188,6 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
     gradient_array(l.output, m*k*l.batch, l.activation, l.delta);
     backward_bias(l.bias_updates, l.delta, l.batch, l.n, k);
 
-    if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
-
     for(i = 0; i < l.batch; ++i){
         float *a = l.delta + i*m*k;
         float *b = l.col_image;
diff --git a/src/cost_layer.c b/src/cost_layer.c
index 76aa17e10a0e429e983d0ba615c2dcafd26e8f35..d1ae6e5bed8ebe9e47eb236d2c2218a05146c7cf 100644
--- a/src/cost_layer.c
+++ b/src/cost_layer.c
@@ -61,7 +61,7 @@ void forward_cost_layer(cost_layer l, network_state state)
 
 void backward_cost_layer(const cost_layer l, network_state state)
 {
-    copy_cpu(l.batch*l.inputs, l.delta, 1, state.delta, 1);
+    axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
 }
 
 #ifdef GPU
@@ -92,7 +92,7 @@ void forward_cost_layer_gpu(cost_layer l, network_state state)
 
 void backward_cost_layer_gpu(const cost_layer l, network_state state)
 {
-    copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
+    axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
 }
 #endif
 
diff --git a/src/deconvolutional_layer.c b/src/deconvolutional_layer.c
index 524fc958bd8cba0863cd0ddf186cf43ba0e3d3ec..0f4e1e857f5bdf3675ad41f14ee0af04145c681d 100644
--- a/src/deconvolutional_layer.c
+++ b/src/deconvolutional_layer.c
@@ -159,8 +159,6 @@ void backward_deconvolutional_layer(deconvolutional_layer l, network_state state
     gradient_array(l.output, size*l.n*l.batch, l.activation, l.delta);
     backward_bias(l.bias_updates, l.delta, l.batch, l.n, size);
 
-    if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
-
     for(i = 0; i < l.batch; ++i){
         int m = l.c;
         int n = l.size*l.size*l.n;
diff --git a/src/detection_layer.c b/src/detection_layer.c
index 9ef89d9a6ab3b08d6fba2cbe726c4dc631455f1e..6a25819be93bf1444a8af5943d55c6656dd7f74c 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -141,20 +141,20 @@ void backward_detection_layer(const detection_layer l, network_state state)
         float scale = 1;
         float latent_delta = 0;
         if(l.joint) scale = state.input[in_i++];
-        else if (l.objectness)   state.delta[in_i++] = -l.delta[out_i++];
-        else if (l.background) state.delta[in_i++] = scale*l.delta[out_i++];
+        else if (l.objectness)   state.delta[in_i++] += -l.delta[out_i++];
+        else if (l.background) state.delta[in_i++] += scale*l.delta[out_i++];
         for(j = 0; j < l.classes; ++j){
             latent_delta += state.input[in_i]*l.delta[out_i];
-            state.delta[in_i++] = scale*l.delta[out_i++];
+            state.delta[in_i++] += scale*l.delta[out_i++];
         }
 
         if (l.objectness) {
 
         }else if (l.background) gradient_array(l.output + out_i, l.coords, LOGISTIC, l.delta + out_i);
         for(j = 0; j < l.coords; ++j){
-            state.delta[in_i++] = l.delta[out_i++];
+            state.delta[in_i++] += l.delta[out_i++];
         }
-        if(l.joint) state.delta[in_i-l.coords-l.classes-l.joint] = latent_delta;
+        if(l.joint) state.delta[in_i-l.coords-l.classes-l.joint] += latent_delta;
     }
 }
 
@@ -198,7 +198,8 @@ void backward_detection_layer_gpu(detection_layer l, network_state state)
     cpu_state.truth = truth_cpu;
     cpu_state.delta = delta_cpu;
 
-    cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
+    cuda_pull_array(state.input, in_cpu,    l.batch*l.inputs);
+    cuda_pull_array(state.delta, delta_cpu, l.batch*l.inputs);
     cuda_pull_array(l.delta_gpu, l.delta, l.batch*outputs);
     backward_detection_layer(l, cpu_state);
     cuda_push_array(state.delta, delta_cpu, l.batch*l.inputs);
diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c
index bc3aa68faae9304d6790738fe7ec71b5d47f086f..ef06175c12d2398cec391b2d3248eb237400bd5d 100644
--- a/src/maxpool_layer.c
+++ b/src/maxpool_layer.c
@@ -114,7 +114,6 @@ void backward_maxpool_layer(const maxpool_layer l, network_state state)
     int h = (l.h-1)/l.stride + 1;
     int w = (l.w-1)/l.stride + 1;
     int c = l.c;
-    memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
     for(i = 0; i < h*w*c*l.batch; ++i){
         int index = l.indexes[i];
         state.delta[index] += l.delta[i];
diff --git a/src/maxpool_layer_kernels.cu b/src/maxpool_layer_kernels.cu
index 6c633a97c8bf2233235dec418c3b45f9ef0ce9c4..8f69f905d9a8e0410f2787af3e23c92a17917539 100644
--- a/src/maxpool_layer_kernels.cu
+++ b/src/maxpool_layer_kernels.cu
@@ -77,7 +77,7 @@ __global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_
             d += (valid && indexes[out_index] == index) ? delta[out_index] : 0;
         }
     }
-    prev_delta[index] = d;
+    prev_delta[index] += d;
 }
 
 extern "C" void forward_maxpool_layer_gpu(maxpool_layer layer, network_state state)
diff --git a/src/network.c b/src/network.c
index e85dfe9830bd428bd436d414fd6370ad7037ba1d..5b52da95e0276812427ae91b19f9c47ec078e0f0 100644
--- a/src/network.c
+++ b/src/network.c
@@ -68,6 +68,9 @@ void forward_network(network net, network_state state)
     int i;
     for(i = 0; i < net.n; ++i){
         layer l = net.layers[i];
+        if(l.delta){
+            scal_cpu(l.outputs * l.batch, 0, l.delta, 1);
+        }
         if(l.type == CONVOLUTIONAL){
             forward_convolutional_layer(l, state);
         } else if(l.type == DECONVOLUTIONAL){
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 3a4f0bfa7f6d42787fba1aa95ef543db33db9f6c..6562590cead8c8c1a63e28faac17c0f8c2f442b0 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -21,6 +21,7 @@ extern "C" {
 #include "softmax_layer.h"
 #include "dropout_layer.h"
 #include "route_layer.h"
+#include "blas.h"
 }
 
 float * get_network_output_gpu_layer(network net, int i);
@@ -32,6 +33,9 @@ void forward_network_gpu(network net, network_state state)
     int i;
     for(i = 0; i < net.n; ++i){
         layer l = net.layers[i];
+        if(l.delta){
+            scal_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
+        }
         if(l.type == CONVOLUTIONAL){
             forward_convolutional_layer_gpu(l, state);
         } else if(l.type == DECONVOLUTIONAL){
diff --git a/src/normalization_layer.c b/src/normalization_layer.c
index d0805592a8f40cdd1eccc1581920c496a70d40c7..587ece78c820e88ec7607955d93314721193d9d4 100644
--- a/src/normalization_layer.c
+++ b/src/normalization_layer.c
@@ -90,6 +90,7 @@ void forward_normalization_layer(const layer layer, network_state state)
 void backward_normalization_layer(const layer layer, network_state state)
 {
     // TODO This is approximate ;-)
+    // Also this should add in to delta instead of overwritting.
 
     int w = layer.w;
     int h = layer.h;
diff --git a/src/route_layer.c b/src/route_layer.c
index e3802b7d83de6c0dfb814cb359587e3006761dd2..67b606c3f4463f629af17e52383ad267665d6fbf 100644
--- a/src/route_layer.c
+++ b/src/route_layer.c
@@ -54,7 +54,7 @@ void backward_route_layer(const route_layer l, network net)
         float *delta = net.layers[index].delta;
         int input_size = l.input_sizes[i];
         for(j = 0; j < l.batch; ++j){
-            copy_cpu(input_size, l.delta + offset + j*l.outputs, 1, delta + j*input_size, 1);
+            axpy_cpu(input_size, 1, l.delta + offset + j*l.outputs, 1, delta + j*input_size, 1);
         }
         offset += input_size;
     }
@@ -85,7 +85,7 @@ void backward_route_layer_gpu(const route_layer l, network net)
         float *delta = net.layers[index].delta_gpu;
         int input_size = l.input_sizes[i];
         for(j = 0; j < l.batch; ++j){
-            copy_ongpu(input_size, l.delta_gpu + offset + j*l.outputs, 1, delta + j*input_size, 1);
+            axpy_ongpu(input_size, 1, l.delta_gpu + offset + j*l.outputs, 1, delta + j*input_size, 1);
         }
         offset += input_size;
     }
diff --git a/src/softmax_layer.c b/src/softmax_layer.c
index ea22d0593ca75a14d7154f81eb27a41f4c5201df..0d19acad0f07055729c1be972b305ad313dd14e8 100644
--- a/src/softmax_layer.c
+++ b/src/softmax_layer.c
@@ -58,7 +58,7 @@ void backward_softmax_layer(const softmax_layer l, network_state state)
 {
     int i;
     for(i = 0; i < l.inputs*l.batch; ++i){
-        state.delta[i] = l.delta[i];
+        state.delta[i] += l.delta[i];
     }
 }
 
diff --git a/src/softmax_layer_kernels.cu b/src/softmax_layer_kernels.cu
index 0529f755e758e8440452dd2b7544c6800caa929b..8fbaf1981dbd44c74ac00f67ab9de2ad6f741bb9 100644
--- a/src/softmax_layer_kernels.cu
+++ b/src/softmax_layer_kernels.cu
@@ -42,7 +42,7 @@ extern "C" void forward_softmax_layer_gpu(const softmax_layer layer, network_sta
 
 extern "C" void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
 {
-    copy_ongpu(layer.batch*layer.inputs, layer.delta_gpu, 1, state.delta, 1);
+    axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, state.delta, 1);
 }
 
 /* This is if you want softmax w/o log-loss classification. You probably don't.