diff --git a/cfg/yolo.cfg b/cfg/yolo.cfg
index 088edf81573e83c59edd7137cbc07b6fe1433591..2a0cd98fbd07c94aa0840c528a12b1b60a004928 100644
--- a/cfg/yolo.cfg
+++ b/cfg/yolo.cfg
@@ -5,8 +5,8 @@ subdivisions=1
 # Training
 # batch=64
 # subdivisions=8
-width=608
-height=608
+width=416
+height=416
 channels=3
 momentum=0.9
 decay=0.0005
diff --git a/src/classifier.c b/src/classifier.c
index 491e60b9e88f43f6cfae53bb63f8666c19a9abb6..039ad66820f4f748de7ae9796fd98e53af1d49bd 100644
--- a/src/classifier.c
+++ b/src/classifier.c
@@ -698,7 +698,7 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
         float *X = r.data;
         time=clock();
         float *predictions = network_predict(net, X);
-        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 0, 1);
+        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 1, 1);
         top_k(predictions, net.outputs, top, indexes);
         printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
         for(i = 0; i < top; ++i){
diff --git a/src/coco.c b/src/coco.c
index 74fe3b2f8b57b56707a3f8a3119d2bfe9272a3c4..3bcb651816bfc164efb9a6862d5033be58f6f145 100644
--- a/src/coco.c
+++ b/src/coco.c
@@ -376,9 +376,10 @@ void run_coco(int argc, char **argv)
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
     char *filename = (argc > 5) ? argv[5]: 0;
+    int avg = find_int_arg(argc, argv, "-avg", 1);
     if(0==strcmp(argv[2], "test")) test_coco(cfg, weights, filename, thresh);
     else if(0==strcmp(argv[2], "train")) train_coco(cfg, weights);
     else if(0==strcmp(argv[2], "valid")) validate_coco(cfg, weights);
     else if(0==strcmp(argv[2], "recall")) validate_coco_recall(cfg, weights);
-    else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, coco_classes, 80, frame_skip, prefix, .5, 0,0,0,0);
+    else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, coco_classes, 80, frame_skip, prefix, avg, .5, 0,0,0,0);
 }
diff --git a/src/demo.c b/src/demo.c
index 4ab9716bdb791eb733c6d00508f36ac2506babd7..ff6c732962e3b4d2aa574d05d04233f0b9ed587b 100644
--- a/src/demo.c
+++ b/src/demo.c
@@ -9,7 +9,6 @@
 #include "demo.h"
 #include <sys/time.h>
 
-#define FRAMES 3
 #define DEMO 1
 
 #ifdef OPENCV
@@ -21,80 +20,125 @@ static int demo_classes;
 static float **probs;
 static box *boxes;
 static network net;
-static image in   ;
-static image in_s ;
-static image det  ;
-static image det_s;
-static image disp = {0};
+static image buff [3];
+static image buff_letter[3];
+static int buff_index = 0;
 static CvCapture * cap;
+static IplImage  * ipl;
 static float fps = 0;
 static float demo_thresh = 0;
 static float demo_hier = .5;
+static int running = 0;
 
-static float *predictions[FRAMES];
+static int demo_delay = 0;
+static int demo_frame = 5;
+static int demo_detections = 0;
+static float **predictions;
 static int demo_index = 0;
-static image images[FRAMES];
+static int demo_done = 0;
+static float *last_avg2;
+static float *last_avg;
 static float *avg;
+double demo_time;
 
-void *fetch_in_thread(void *ptr)
+double get_wall_time()
 {
-    in = get_image_from_stream(cap);
-    if(!in.data){
-        error("Stream closed.");
+    struct timeval time;
+    if (gettimeofday(&time,NULL)){
+        return 0;
     }
-    in_s = letterbox_image(in, net.w, net.h);
-    return 0;
+    return (double)time.tv_sec + (double)time.tv_usec * .000001;
 }
 
 void *detect_in_thread(void *ptr)
 {
+    running = 1;
     float nms = .4;
 
     layer l = net.layers[net.n-1];
-    float *X = det_s.data;
+    float *X = buff_letter[(buff_index+2)%3].data;
     float *prediction = network_predict(net, X);
 
     memcpy(predictions[demo_index], prediction, l.outputs*sizeof(float));
-    mean_arrays(predictions, FRAMES, l.outputs, avg);
-    l.output = avg;
-
-    free_image(det_s);
+    mean_arrays(predictions, demo_frame, l.outputs, avg);
+    l.output = last_avg2;
+    if(demo_delay == 0) l.output = avg;
     if(l.type == DETECTION){
         get_detection_boxes(l, 1, 1, demo_thresh, probs, boxes, 0);
     } else if (l.type == REGION){
-        get_region_boxes(l, in.w, in.h, net.w, net.h, demo_thresh, probs, boxes, 0, 0, demo_hier, 1);
+        get_region_boxes(l, buff[0].w, buff[0].h, net.w, net.h, demo_thresh, probs, boxes, 0, 0, demo_hier, 1);
     } else {
         error("Last layer must produce detections\n");
     }
     if (nms > 0) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
+
     printf("\033[2J");
     printf("\033[1;1H");
     printf("\nFPS:%.1f\n",fps);
     printf("Objects:\n\n");
+    image display = buff[(buff_index+2) % 3];
+    draw_detections(display, demo_detections, demo_thresh, boxes, probs, demo_names, demo_alphabet, demo_classes);
 
-    images[demo_index] = det;
-    det = images[(demo_index + FRAMES/2 + 1)%FRAMES];
-    demo_index = (demo_index + 1)%FRAMES;
-
-    draw_detections(det, l.w*l.h*l.n, demo_thresh, boxes, probs, demo_names, demo_alphabet, demo_classes);
+    demo_index = (demo_index + 1)%demo_frame;
+    running = 0;
+    return 0;
+}
 
+void *fetch_in_thread(void *ptr)
+{
+    int status = fill_image_from_stream(cap, buff[buff_index]);
+    letterbox_image_into(buff[buff_index], net.w, net.h, buff_letter[buff_index]);
+    if(status == 0) demo_done = 1;
     return 0;
 }
 
-double get_wall_time()
+void *display_in_thread(void *ptr)
 {
-    struct timeval time;
-    if (gettimeofday(&time,NULL)){
+    show_image_cv(buff[(buff_index + 1)%3], "Demo", ipl);
+    int c = cvWaitKey(1);
+    if (c != -1) c = c%256;
+    if (c == 10){
+        if(demo_delay == 0) demo_delay = 60;
+        else if(demo_delay == 5) demo_delay = 0;
+        else if(demo_delay == 60) demo_delay = 5;
+        else demo_delay = 0;
+    } else if (c == 27) {
+        demo_done = 1;
         return 0;
+    } else if (c == 82) {
+        demo_thresh += .02;
+    } else if (c == 84) {
+        demo_thresh -= .02;
+        if(demo_thresh <= .02) demo_thresh = .02;
+    } else if (c == 83) {
+        demo_hier += .02;
+    } else if (c == 81) {
+        demo_hier -= .02;
+        if(demo_hier <= .0) demo_hier = .0;
+    }
+    return 0;
+}
+
+void *display_loop(void *ptr)
+{
+    while(1){
+        display_in_thread(0);
     }
-    return (double)time.tv_sec + (double)time.tv_usec * .000001;
 }
 
-void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier, int w, int h, int frames, int fullscreen)
+void *detect_loop(void *ptr)
 {
-    //skip = frame_skip;
+    while(1){
+        detect_in_thread(0);
+    }
+}
+
+void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int delay, char *prefix, int avg_frames, float hier, int w, int h, int frames, int fullscreen)
+{
+    demo_delay = delay; /* display cadence: averages/fps refresh when count % (demo_delay+1) == 0 */
+    demo_frame = (avg_frames > 0) ? avg_frames : 1; /* clamp: "-avg 0" or a negative value would give calloc a bad count and make "% demo_frame" divide by zero */
+    predictions = calloc(demo_frame, sizeof(float*)); /* one slot per averaged frame; each slot is allocated later, once l.outputs is known */
     image **alphabet = load_alphabet();
-    int delay = frame_skip;
     demo_names = names;
     demo_alphabet = alphabet;
     demo_classes = classes;
@@ -106,6 +150,8 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
         load_weights(&net, weightfile);
     }
     set_batch_network(&net, 1);
+    pthread_t detect_thread;
+    pthread_t fetch_thread;
 
     srand(2222222);
 
@@ -129,36 +175,25 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
     if(!cap) error("Couldn't connect to webcam.\n");
 
     layer l = net.layers[net.n-1];
+    demo_detections = l.n*l.w*l.h;
     int j;
 
     avg = (float *) calloc(l.outputs, sizeof(float));
-    for(j = 0; j < FRAMES; ++j) predictions[j] = (float *) calloc(l.outputs, sizeof(float));
-    for(j = 0; j < FRAMES; ++j) images[j] = make_image(1,1,3);
+    last_avg  = (float *) calloc(l.outputs, sizeof(float));
+    last_avg2 = (float *) calloc(l.outputs, sizeof(float));
+    for(j = 0; j < demo_frame; ++j) predictions[j] = (float *) calloc(l.outputs, sizeof(float));
 
     boxes = (box *)calloc(l.w*l.h*l.n, sizeof(box));
     probs = (float **)calloc(l.w*l.h*l.n, sizeof(float *));
-    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float));
-
-    pthread_t fetch_thread;
-    pthread_t detect_thread;
-
-    fetch_in_thread(0);
-    det = in;
-    det_s = in_s;
+    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = (float *)calloc(l.classes+1, sizeof(float));
 
-    fetch_in_thread(0);
-    detect_in_thread(0);
-    disp = det;
-    det = in;
-    det_s = in_s;
-
-    for(j = 0; j < FRAMES/2; ++j){
-        fetch_in_thread(0);
-        detect_in_thread(0);
-        disp = det;
-        det = in;
-        det_s = in_s;
-    }
+    buff[0] = get_image_from_stream(cap);
+    buff[1] = copy_image(buff[0]);
+    buff[2] = copy_image(buff[0]);
+    buff_letter[0] = letterbox_image(buff[0], net.w, net.h);
+    buff_letter[1] = letterbox_image(buff[0], net.w, net.h);
+    buff_letter[2] = letterbox_image(buff[0], net.w, net.h);
+    ipl = cvCreateImage(cvSize(buff[0].w,buff[0].h), IPL_DEPTH_8U, buff[0].c);
 
     int count = 0;
     if(!prefix){
@@ -171,76 +206,34 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
         }
     }
 
-    double before = get_wall_time();
-
-    while(1){
-        ++count;
-        if(1){
-            if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed");
-            if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed");
-
-            if(!prefix){
-                show_image(disp, "Demo");
-                int c = cvWaitKey(1);
-		if (c != -1) c = c%256;
-                if (c == 10){
-                    if(frame_skip == 0) frame_skip = 60;
-                    else if(frame_skip == 4) frame_skip = 0;
-                    else if(frame_skip == 60) frame_skip = 4;   
-                    else frame_skip = 0;
-                } else if (c == 27) {
-                    return;
-                } else if (c == 82) {
-                    demo_thresh += .02;
-                } else if (c == 84) {
-                    demo_thresh -= .02;
-                    if(demo_thresh <= .02) demo_thresh = .02;
-                } else if (c == 83) {
-                    demo_hier += .02;
-                } else if (c == 81) {
-                    demo_hier -= .02;
-                    if(demo_hier <= .0) demo_hier = .0;
-                }
-            }else{
-                char buff[256];
-                sprintf(buff, "%s_%08d", prefix, count);
-                save_image(disp, buff);
-            }
-
-            pthread_join(fetch_thread, 0);
-            pthread_join(detect_thread, 0);
-
-            if(delay == 0){
-                free_image(disp);
-                disp  = det;
+    demo_time = get_wall_time();
+
+    while(!demo_done){
+        buff_index = (buff_index + 1) %3;
+        if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed");
+        if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed");
+        if(!prefix){
+            if(count % (demo_delay+1) == 0){
+                fps = 1./(get_wall_time() - demo_time);
+                demo_time = get_wall_time();
+                float *swap = last_avg;
+                last_avg  = last_avg2;
+                last_avg2 = swap;
+                memcpy(last_avg, avg, l.outputs*sizeof(float));
             }
-            det   = in;
-            det_s = in_s;
-        }else {
-            fetch_in_thread(0);
-            det   = in;
-            det_s = in_s;
-            detect_in_thread(0);
-            if(delay == 0) {
-                free_image(disp);
-                disp = det;
-            }
-            show_image(disp, "Demo");
-            cvWaitKey(1);
-        }
-        --delay;
-        if(delay < 0){
-            delay = frame_skip;
-
-            double after = get_wall_time();
-            float curr = 1./(after - before);
-            fps = curr;
-            before = after;
+            display_in_thread(0);
+        }else{
+            char name[256];
+            sprintf(name, "%s_%08d", prefix, count);
+            save_image(buff[(buff_index + 1)%3], name);
         }
+        pthread_join(fetch_thread, 0);
+        pthread_join(detect_thread, 0);
+        ++count;
     }
 }
 #else
-void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh, int w, int h, int fps, int fullscreen)
+void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int delay, char *prefix, int avg, float hier, int w, int h, int frames, int fullscreen)
 {
     fprintf(stderr, "Demo needs OpenCV for webcam images.\n");
 }
diff --git a/src/demo.h b/src/demo.h
index 2c64a4692c087951f847126812db416015c86890..d920759c12ab08c2f517cd6ca6d30032d2daffe3 100644
--- a/src/demo.h
+++ b/src/demo.h
@@ -2,6 +2,6 @@
 #define DEMO_H
 
 #include "image.h"
-void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh, int w, int h, int fps, int fullscreen);
+void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, int avg, float hier_thresh, int w, int h, int fps, int fullscreen);
 
 #endif
diff --git a/src/detection_layer.c b/src/detection_layer.c
index f9b4e4e73ba0d65db1c4e5b6db0dfc5c4d6b7c72..fd79cc703411e162fcda2fc4e17ce768620d90b0 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -259,8 +259,8 @@ void forward_detection_layer_gpu(const detection_layer l, network net)
         return;
     }
 
-    float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
-    float *truth_cpu = 0;
+    //float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
+    //float *truth_cpu = 0;
 
     forward_detection_layer(l, net);
     cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
diff --git a/src/detector.c b/src/detector.c
index b69d21ff568140790936503ebfd8624c3bf59d23..c205f944d4b03fd50744400ca5334ea74f815a94 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -292,7 +292,7 @@ void validate_detector_flip(char *datacfg, char *cfgfile, char *weightfile, char
 
     box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
     float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
-    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
+    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes+1, sizeof(float *));
 
     int m = plist->size;
     int i=0;
@@ -428,7 +428,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out
 
     box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
     float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
-    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
+    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes+1, sizeof(float *));
 
     int m = plist->size;
     int i=0;
@@ -521,7 +521,7 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
     int j, k;
     box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
     float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
-    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
+    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes+1, sizeof(float *));
 
     int m = plist->size;
     int i=0;
@@ -659,6 +659,7 @@ void run_detector(int argc, char **argv)
     float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
     int cam_index = find_int_arg(argc, argv, "-c", 0);
     int frame_skip = find_int_arg(argc, argv, "-s", 0);
+    int avg = find_int_arg(argc, argv, "-avg", 3);
     if(argc < 4){
         fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
         return;
@@ -707,6 +708,6 @@ void run_detector(int argc, char **argv)
         int classes = option_find_int(options, "classes", 20);
         char *name_list = option_find_str(options, "names", "data/names.list");
         char **names = get_labels(name_list);
-        demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, hier_thresh, width, height, fps, fullscreen);
+        demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, avg, hier_thresh, width, height, fps, fullscreen);
     }
 }
diff --git a/src/image.c b/src/image.c
index a5cc135c33af7b50fdcc8fd7c737f29d88ca5696..e1d19442acca7cbb1c6672841b0cfaf2b2dc4ac7 100644
--- a/src/image.c
+++ b/src/image.c
@@ -216,6 +216,7 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs,
             if (alphabet) {
                 image label = get_label(alphabet, names[class], (im.h*.03)/10);
                 draw_label(im, top + width, left, label, rgb);
+                free_image(label);
             }
         }
     }
@@ -394,6 +395,11 @@ void normalize_image2(image p)
     free(max);
 }
 
+void copy_image_into(image src, image dest) /* copy src's pixel floats into dest's existing data buffer; nothing is allocated or resized here */
+{
+    memcpy(dest.data, src.data, src.h*src.w*src.c*sizeof(float)); /* byte count comes from src's dimensions only; assumes dest.data is at least that large -- TODO confirm at call sites */
+}
+
 image copy_image(image p)
 {
     image copy = p;
@@ -413,19 +419,16 @@ void rgbgr_image(image im)
 }
 
 #ifdef OPENCV
-void show_image_cv(image p, const char *name)
+void show_image_cv(image p, const char *name, IplImage *disp)
 {
     int x,y,k;
-    image copy = copy_image(p);
-    constrain_image(copy);
-    if(p.c == 3) rgbgr_image(copy);
+    if(p.c == 3) rgbgr_image(p);
     //normalize_image(copy);
 
     char buff[256];
     //sprintf(buff, "%s (%d)", name, windows);
     sprintf(buff, "%s", name);
 
-    IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c);
     int step = disp->widthStep;
     cvNamedWindow(buff, CV_WINDOW_NORMAL); 
     //cvMoveWindow(buff, 100*(windows%10) + 200*(windows/10), 100*(windows%10));
@@ -433,11 +436,10 @@ void show_image_cv(image p, const char *name)
     for(y = 0; y < p.h; ++y){
         for(x = 0; x < p.w; ++x){
             for(k= 0; k < p.c; ++k){
-                disp->imageData[y*step + x*p.c + k] = (unsigned char)(get_pixel(copy,x,y,k)*255);
+                disp->imageData[y*step + x*p.c + k] = (unsigned char)(get_pixel(p,x,y,k)*255);
             }
         }
     }
-    free_image(copy);
     if(0){
         int w = 448;
         int h = w*p.h/p.w;
@@ -451,14 +453,18 @@ void show_image_cv(image p, const char *name)
         cvReleaseImage(&buffer);
     }
     cvShowImage(buff, disp);
-    cvReleaseImage(&disp);
 }
 #endif
 
 void show_image(image p, const char *name)
 {
 #ifdef OPENCV
-    show_image_cv(p, name);
+    IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c);
+    image copy = copy_image(p);
+    constrain_image(copy);
+    show_image_cv(copy, name, disp);
+    free_image(copy);
+    cvReleaseImage(&disp);
 #else
     fprintf(stderr, "Not compiled with OpenCV, saving to %s.png instead\n", name);
     save_image(p, name);
@@ -467,23 +473,31 @@ void show_image(image p, const char *name)
 
 #ifdef OPENCV
 
-image ipl_to_image(IplImage* src)
+void ipl_into_image(IplImage* src, image im)
 {
     unsigned char *data = (unsigned char *)src->imageData;
     int h = src->height;
     int w = src->width;
     int c = src->nChannels;
     int step = src->widthStep;
-    image out = make_image(w, h, c);
-    int i, j, k, count=0;;
+    int i, j, k;
 
-    for(k= 0; k < c; ++k){
-        for(i = 0; i < h; ++i){
+    for(i = 0; i < h; ++i){
+        for(k= 0; k < c; ++k){
             for(j = 0; j < w; ++j){
-                out.data[count++] = data[i*step + j*c + k]/255.;
+                im.data[k*w*h + i*w + j] = data[i*step + j*c + k]/255.;
             }
         }
     }
+}
+
+image ipl_to_image(IplImage* src)
+{
+    int h = src->height;
+    int w = src->width;
+    int c = src->nChannels;
+    image out = make_image(w, h, c);
+    ipl_into_image(src, out);
     return out;
 }
 
@@ -513,6 +527,14 @@ image load_image_cv(char *filename, int channels)
     return out;
 }
 
+/* Query n frames from cap and discard them; does nothing when n <= 0. */
+void flush_stream_buffer(CvCapture *cap, int n)
+{
+    int left;
+    /* the frame returned by cvQueryFrame is deliberately never looked at */
+    for(left = n; left > 0; --left) cvQueryFrame(cap);
+}
+
 image get_image_from_stream(CvCapture *cap)
 {
     IplImage* src = cvQueryFrame(cap);
@@ -522,6 +544,15 @@ image get_image_from_stream(CvCapture *cap)
     return im;
 }
 
+int fill_image_from_stream(CvCapture *cap, image im) /* returns 1 when a frame was written into im, 0 when cvQueryFrame returned no frame */
+{
+    IplImage *frame = cvQueryFrame(cap);
+    if (frame == 0) return 0;      /* no frame: im is left exactly as it was */
+    ipl_into_image(frame, im);     /* convert into im's existing buffer, no allocation in this function */
+    rgbgr_image(im);               /* channel-order pass applied to the freshly converted pixels */
+    return 1;
+}
+
 void save_image_jpg(image p, const char *name)
 {
     image copy = copy_image(p);
@@ -794,6 +825,22 @@ void composite_3d(char *f1, char *f2, char *out, int delta)
 #endif
 }
 
+/* Resize im to fit inside w x h with its aspect ratio kept, then embed it into the caller-supplied boxed image. */
+void letterbox_image_into(image im, int w, int h, image boxed)
+{
+    int fit_w = w;
+    int fit_h = h;
+    /* the smaller scale factor fills its axis; the other axis is scaled with integer math */
+    if (((float)w/im.w) < ((float)h/im.h)) {
+        fit_h = (im.h * w)/im.w;
+    } else {
+        fit_w = (im.w * h)/im.h;
+    }
+    image scaled = resize_image(im, fit_w, fit_h);
+    embed_image(scaled, boxed, (w-fit_w)/2, (h-fit_h)/2);
+    free_image(scaled);
+}
+
 image letterbox_image(image im, int w, int h)
 {
     int new_w = im.w;
diff --git a/src/image.h b/src/image.h
index fd4ca414d0d992938f9e0572af82696e15497429..abe99d6cda6783018404548466ad5e712bc3f86f 100644
--- a/src/image.h
+++ b/src/image.h
@@ -29,7 +29,11 @@ typedef struct {
 #ifndef __cplusplus
 #ifdef OPENCV
 image get_image_from_stream(CvCapture *cap);
+int fill_image_from_stream(CvCapture *cap, image im);
 image ipl_to_image(IplImage* src);
+void ipl_into_image(IplImage* src, image im);
+void flush_stream_buffer(CvCapture *cap, int n);
+void show_image_cv(image p, const char *name, IplImage *disp);
 #endif
 #endif
 
@@ -49,6 +53,7 @@ image random_crop_image(image im, int w, int h);
 image random_augment_image(image im, float angle, float aspect, int low, int high, int size);
 void random_distort_image(image im, float hue, float saturation, float exposure);
 image letterbox_image(image im, int w, int h);
+void letterbox_image_into(image im, int w, int h, image boxed);
 image resize_image(image im, int w, int h);
 image resize_min(image im, int min);
 image resize_max(image im, int max);
@@ -96,6 +101,7 @@ image make_random_image(int w, int h, int c);
 image make_empty_image(int w, int h, int c);
 image float_to_image(int w, int h, int c, float *data);
 image copy_image(image p);
+void copy_image_into(image src, image dest);
 image load_image(char *filename, int w, int h, int c);
 image load_image_color(char *filename, int w, int h);
 image **load_alphabet();
diff --git a/src/region_layer.c b/src/region_layer.c
index e9f18b526eab8c314e7b793db67d85572ba38b94..a9a31208006d23a39a7979968a44f745fcf56cf0 100644
--- a/src/region_layer.c
+++ b/src/region_layer.c
@@ -406,7 +406,7 @@ void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, f
                     probs[index][j] = (prob > thresh) ? prob : 0;
                     if(prob > max) max = prob;
                     // TODO REMOVE
-                    // if (j != 15 && j != 16) probs[index][j] = 0; 
+                    // if (j == 56 ) probs[index][j] = 0; 
                     /*
                        if (j != 0) probs[index][j] = 0; 
                        int blacklist[] = {121, 497, 482, 504, 122, 518,481, 418, 542, 491, 914, 478, 120, 510,500};
diff --git a/src/yolo.c b/src/yolo.c
index 48d63dfb912b8463291a630fa28d39716dba3d7d..6ddfb6ef52249eece633628bf87c38e3361c543b 100644
--- a/src/yolo.c
+++ b/src/yolo.c
@@ -340,6 +340,7 @@ void run_yolo(int argc, char **argv)
         return;
     }
 
+    int avg = find_int_arg(argc, argv, "-avg", 1);
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
     char *filename = (argc > 5) ? argv[5]: 0;
@@ -347,5 +348,5 @@ void run_yolo(int argc, char **argv)
     else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
     else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);
     else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);
-    else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix, .5, 0,0,0,0);
+    else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix, avg, .5, 0,0,0,0);
 }