From c40cdeb4021fc1a638969563972f13c9f5e90d74 Mon Sep 17 00:00:00 2001 From: Joseph Redmon <pjreddie@gmail.com> Date: Fri, 9 Oct 2015 12:50:43 -0700 Subject: [PATCH] lots of comparator stuff --- Makefile | 2 +- cfg/darknet.cfg | 1 + src/coco.c | 1 + src/compare.c | 110 +++++++++++++++++++++++++++----------- src/convolutional_layer.c | 2 +- src/darknet.c | 3 ++ src/data.c | 8 +-- src/data.h | 1 + src/dice.c | 2 +- src/imagenet.c | 2 +- src/layer.h | 4 ++ src/network.c | 6 +-- src/network.h | 2 +- src/option_list.c | 18 +++++++ src/option_list.h | 1 + src/parser.c | 23 ++------ src/region_layer.c | 38 ++++++++++++- src/swag.c | 99 +++++++++++++++++++++++++++++++++- 18 files changed, 258 insertions(+), 65 deletions(-) diff --git a/Makefile b/Makefile index 22e89a1..26c4076 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ CFLAGS+= -DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o classifier.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o endif diff --git a/cfg/darknet.cfg b/cfg/darknet.cfg index 0b3c46c..00e9c36 100644 --- a/cfg/darknet.cfg +++ b/cfg/darknet.cfg @@ -104,6 +104,7 @@ output=1000 activation=leaky [softmax] +groups=1 [cost] type=sse diff --git a/src/coco.c b/src/coco.c index c016548..f6b135f 100644 --- a/src/coco.c +++ b/src/coco.c @@ -135,6 +135,7 @@ void get_probs(float *predictions, int total, int classes, int inc, float **prob } } } + void get_boxes(float *predictions, int n, int num_boxes, int per_box, box *boxes) { int i,j; diff --git a/src/compare.c b/src/compare.c index 74c1cf5..76e0b60 100644 --- a/src/compare.c +++ b/src/compare.c @@ -150,17 +150,20 @@ typedef struct { network net; char *filename; int class; + int classes; float elo; + float *elos; } sortable_bbox; int total_compares = 0; +int current_class = 0; int elo_comparator(const void*a, const void *b) { sortable_bbox box1 = *(sortable_bbox*)a; sortable_bbox box2 = *(sortable_bbox*)b; - if(box1.elo == box2.elo) return 0; - if(box1.elo > box2.elo) return -1; + if(box1.elos[current_class] == box2.elos[current_class]) return 0; + if(box1.elos[current_class] > box2.elos[current_class]) return -1; return 1; } @@ -188,16 +191,38 @@ int bbox_comparator(const void *a, const void *b) return -1; } -void bbox_fight(sortable_bbox *a, sortable_bbox *b) +void bbox_update(sortable_bbox *a, sortable_bbox *b, int class, int result) { int k = 32; - int result = bbox_comparator(a,b); - float EA = 1./(1+pow(10, (b->elo - a->elo)/400.)); - float EB = 1./(1+pow(10, (a->elo - b->elo)/400.)); - float SA = 1.*(result > 0); - float SB = 1.*(result < 0); - a->elo = a->elo + k*(SA - EA); - b->elo = b->elo + k*(SB - EB); + float EA = 1./(1+pow(10, (b->elos[class] - a->elos[class])/400.)); + float EB = 1./(1+pow(10, (a->elos[class] - b->elos[class])/400.)); + float SA = result ? 1 : 0; + float SB = result ? 0 : 1; + a->elos[class] += k*(SA - EA); + b->elos[class] += k*(SB - EB); +} + +void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class) +{ + image im1 = load_image_color(a->filename, net.w, net.h); + image im2 = load_image_color(b->filename, net.w, net.h); + float *X = calloc(net.w*net.h*net.c, sizeof(float)); + memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float)); + memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float)); + float *predictions = network_predict(net, X); + ++total_compares; + + int i; + for(i = 0; i < classes; ++i){ + if(class < 0 || class == i){ + int result = predictions[i*2] > predictions[i*2+1]; + bbox_update(a, b, i, result); + } + } + + free_image(im1); + free_image(im2); + free(X); } void SortMaster3000(char *filename, char *weightfile) @@ -233,7 +258,8 @@ void SortMaster3000(char *filename, char *weightfile) void BattleRoyaleWithCheese(char *filename, char *weightfile) { - int i = 0; + int classes = 20; + int i,j; network net = parse_network_cfg(filename); if(weightfile){ load_weights(&net, weightfile); @@ -241,47 +267,67 @@ void BattleRoyaleWithCheese(char *filename, char *weightfile) srand(time(0)); set_batch_network(&net, 1); - //list *plist = get_paths("data/compare.sort.list"); - list *plist = get_paths("data/compare.cat.list"); + list *plist = get_paths("data/compare.sort.list"); + //list *plist = get_paths("data/compare.small.list"); + //list *plist = get_paths("data/compare.cat.list"); //list *plist = get_paths("data/compare.val.old"); char **paths = (char **)list_to_array(plist); int N = plist->size; + int total = N; free_list(plist); sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox)); printf("Battling %d boxes...\n", N); for(i = 0; i < N; ++i){ boxes[i].filename = paths[i]; boxes[i].net = net; - boxes[i].class = 7; - boxes[i].elo = 1500; + boxes[i].classes = classes; + boxes[i].elos = calloc(classes, sizeof(float));; + for(j = 0; j < classes; ++j){ + boxes[i].elos[j] = 1500; + } } int round; clock_t time=clock(); - for(round = 1; round <= 500; ++round){ + for(round = 1; round <= 4; ++round){ clock_t round_time=clock(); printf("Round: %d\n", round); - qsort(boxes, N, sizeof(sortable_bbox), elo_comparator); - sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10); shuffle(boxes, N, sizeof(sortable_bbox)); for(i = 0; i < N/2; ++i){ - bbox_fight(boxes+i*2, boxes+i*2+1); - } - if(round >= 4 && 0){ - qsort(boxes, N, sizeof(sortable_bbox), elo_comparator); - if(round == 4){ - N = N/2; - }else{ - N = (N*9/10)/2*2; - } + bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1); } printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N); } - qsort(boxes, N, sizeof(sortable_bbox), elo_comparator); - FILE *outfp = fopen("results/battle.log", "w"); - for(i = 0; i < N; ++i){ - fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elo); + + int class; + + for (class = 0; class < classes; ++class){ + + N = total; + current_class = class; + qsort(boxes, N, sizeof(sortable_bbox), elo_comparator); + N /= 2; + + for(round = 1; round <= 20; ++round){ + clock_t round_time=clock(); + printf("Round: %d\n", round); + + sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10); + for(i = 0; i < N/2; ++i){ + bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class); + } + qsort(boxes, N, sizeof(sortable_bbox), elo_comparator); + N = (N*9/10)/2*2; + + printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N); + } + char buff[256]; + sprintf(buff, "results/battle_%d.log", class); + FILE *outfp = fopen(buff, "w"); + for(i = 0; i < N; ++i){ + fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class]); + } + fclose(outfp); } - fclose(outfp); printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time)); } diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 6e3f38b..f3609ea 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -61,7 +61,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int l.biases = calloc(n, sizeof(float)); l.bias_updates = calloc(n, sizeof(float)); - //float scale = 1./sqrt(size*size*c); + // float scale = 1./sqrt(size*size*c); float scale = sqrt(2./(size*size*c)); for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale; for(i = 0; i < n; ++i){ diff --git a/src/darknet.c b/src/darknet.c index 9632f91..073156b 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -20,6 +20,7 @@ extern void run_captcha(int argc, char **argv); extern void run_nightmare(int argc, char **argv); extern void run_dice(int argc, char **argv); extern void run_compare(int argc, char **argv); +extern void run_classifier(int argc, char **argv); void change_rate(char *filename, float scale, float add) { @@ -183,6 +184,8 @@ int main(int argc, char **argv) run_swag(argc, argv); } else if (0 == strcmp(argv[1], "coco")){ run_coco(argc, argv); + } else if (0 == strcmp(argv[1], "classifier")){ + run_classifier(argc, argv); } else if (0 == strcmp(argv[1], "compare")){ run_compare(argc, argv); } else if (0 == strcmp(argv[1], "dice")){ diff --git a/src/data.c b/src/data.c index 2853d72..92c3d95 100644 --- a/src/data.c +++ b/src/data.c @@ -366,7 +366,7 @@ void free_data(data d) } } -data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes) +data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes, float jitter) { char **random_paths = get_random_paths(paths, n, m); int i; @@ -385,8 +385,8 @@ data load_data_region(int n, char **paths, int m, int w, int h, int size, int cl int oh = orig.h; int ow = orig.w; - int dw = ow/10; - int dh = oh/10; + int dw = (ow*jitter); + int dh = (oh*jitter); int pleft = (rand_uniform() * 2*dw - dw); int pright = (rand_uniform() * 2*dw - dw); @@ -556,7 +556,7 @@ void *load_thread(void *ptr) } else if (a.type == WRITING_DATA){ *a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h); } else if (a.type == REGION_DATA){ - *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes); + *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter); } else if (a.type == COMPARE_DATA){ *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h); } else if (a.type == IMAGE_DATA){ diff --git a/src/data.h b/src/data.h index b91819f..0dacea2 100644 --- a/src/data.h +++ b/src/data.h @@ -44,6 +44,7 @@ typedef struct load_args{ int num_boxes; int classes; int background; + float jitter; data *d; image *im; image *resized; diff --git a/src/dice.c b/src/dice.c index fdc535e..6f148b0 100644 --- a/src/dice.c +++ b/src/dice.c @@ -61,7 +61,7 @@ void validate_dice(char *filename, char *weightfile) free_list(plist); data val = load_data(paths, m, 0, labels, 6, net.w, net.h); - float *acc = network_accuracies(net, val); + float *acc = network_accuracies(net, val, 2); printf("Validation Accuracy: %f, %d images\n", acc[0], m); free_data(val); } diff --git a/src/imagenet.c b/src/imagenet.c index 567a8c4..1701a2a 100644 --- a/src/imagenet.c +++ b/src/imagenet.c @@ -133,7 +133,7 @@ void validate_imagenet(char *filename, char *weightfile) printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time)); time=clock(); - float *acc = network_accuracies(net, val); + float *acc = network_accuracies(net, val, 5); avg_acc += acc[0]; avg_top5 += acc[1]; printf("%d: top1: %f, top5: %f, %lf seconds, %d images\n", i, avg_acc/i, avg_top5/i, sec(clock()-time), val.X.rows); diff --git a/src/layer.h b/src/layer.h index 808aba4..49f144d 100644 --- a/src/layer.h +++ b/src/layer.h @@ -29,6 +29,9 @@ typedef struct { COST_TYPE cost_type; int batch; int forced; + int object_logistic; + int class_logistic; + int coord_logistic; int inputs; int outputs; int truths; @@ -45,6 +48,7 @@ typedef struct { int sqrt; int flip; float angle; + float jitter; float saturation; float exposure; int softmax; diff --git a/src/network.c b/src/network.c index 7f19318..063a1bb 100644 --- a/src/network.c +++ b/src/network.c @@ -540,12 +540,12 @@ float network_accuracy(network net, data d) return acc; } -float *network_accuracies(network net, data d) +float *network_accuracies(network net, data d, int n) { static float acc[2]; matrix guess = network_predict_data(net, d); - acc[0] = matrix_topk_accuracy(d.y, guess,1); - acc[1] = matrix_topk_accuracy(d.y, guess,5); + acc[0] = matrix_topk_accuracy(d.y, guess, 1); + acc[1] = matrix_topk_accuracy(d.y, guess, n); free_matrix(guess); return acc; } diff --git a/src/network.h b/src/network.h index 5a39f08..78ad0fe 100644 --- a/src/network.h +++ b/src/network.h @@ -70,7 +70,7 @@ float train_network_sgd(network net, data d, int n); matrix network_predict_data(network net, data test); float *network_predict(network net, float *input); float network_accuracy(network net, data d); -float *network_accuracies(network net, data d); +float *network_accuracies(network net, data d, int n); float network_accuracy_multi(network net, data d, int n); void top_predictions(network net, int n, int *index); float *get_network_output(network net); diff --git a/src/option_list.c b/src/option_list.c index f5536e1..7d68ead 100644 --- a/src/option_list.c +++ b/src/option_list.c @@ -3,6 +3,24 @@ #include <string.h> #include "option_list.h" +int read_option(char *s, list *options) +{ + size_t i; + size_t len = strlen(s); + char *val = 0; + for(i = 0; i < len; ++i){ + if(s[i] == '='){ + s[i] = '\0'; + val = s+i+1; + break; + } + } + if(i == len-1) return 0; + char *key = s; + option_insert(options, key, val); + return 1; +} + void option_insert(list *l, char *key, char *val) { kvp *p = malloc(sizeof(kvp)); diff --git a/src/option_list.h b/src/option_list.h index 4441462..d0417aa 100644 --- a/src/option_list.h +++ b/src/option_list.h @@ -9,6 +9,7 @@ typedef struct{ } kvp; +int read_option(char *s, list *options); void option_insert(list *l, char *key, char *val); char *option_find(list *l, char *key); char *option_find_str(list *l, char *key, char *def); diff --git a/src/parser.c b/src/parser.c index 6daeb13..a3400d0 100644 --- a/src/parser.c +++ b/src/parser.c @@ -186,11 +186,16 @@ region_layer parse_region(list *options, size_params params) layer.softmax = option_find_int(options, "softmax", 0); layer.sqrt = option_find_int(options, "sqrt", 0); + layer.object_logistic = option_find_int(options, "object_logistic", 0); + layer.class_logistic = option_find_int(options, "class_logistic", 0); + layer.coord_logistic = option_find_int(options, "coord_logistic", 0); + layer.coord_scale = option_find_float(options, "coord_scale", 1); layer.forced = option_find_int(options, "forced", 0); layer.object_scale = option_find_float(options, "object_scale", 1); layer.noobject_scale = option_find_float(options, "noobject_scale", 1); layer.class_scale = option_find_float(options, "class_scale", 1); + layer.jitter = option_find_float(options, "jitter", .1); return layer; } @@ -532,24 +537,6 @@ int is_route(section *s) return (strcmp(s->type, "[route]")==0); } -int read_option(char *s, list *options) -{ - size_t i; - size_t len = strlen(s); - char *val = 0; - for(i = 0; i < len; ++i){ - if(s[i] == '='){ - s[i] = '\0'; - val = s+i+1; - break; - } - } - if(i == len-1) return 0; - char *key = s; - option_insert(options, key, val); - return 1; -} - list *read_cfg(char *filename) { FILE *file = fopen(filename, "r"); diff --git a/src/region_layer.c b/src/region_layer.c index 4d8c2a4..3239f87 100644 --- a/src/region_layer.c +++ b/src/region_layer.c @@ -57,6 +57,28 @@ void forward_region_layer(const region_layer l, network_state state) activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC); } } + if (l.object_logistic) { + for(b = 0; b < l.batch; ++b){ + int index = b*l.inputs; + int p_index = index + locations*l.classes; + activate_array(l.output + p_index, locations*l.n, LOGISTIC); + } + } + + if (l.coord_logistic) { + for(b = 0; b < l.batch; ++b){ + int index = b*l.inputs; + int coord_index = index + locations*(l.classes + l.n); + activate_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC); + } + } + + if (l.class_logistic) { + for(b = 0; b < l.batch; ++b){ + int class_index = b*l.inputs; + activate_array(l.output + class_index, locations*l.classes, LOGISTIC); + } + } if(state.train){ float avg_iou = 0; @@ -85,7 +107,6 @@ void forward_region_layer(const region_layer l, network_state state) float best_rmse = 20; if (!is_obj){ - //printf("."); continue; } @@ -113,6 +134,7 @@ void forward_region_layer(const region_layer l, network_state state) } float iou = box_iou(out, truth); + //iou = 0; float rmse = box_rmse(out, truth); if(best_iou > 0 || iou > 0){ if(iou > best_iou){ @@ -175,6 +197,20 @@ void forward_region_layer(const region_layer l, network_state state) gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords), LOGISTIC, l.delta + index + locations*l.classes); } + if (l.object_logistic) { + int p_index = index + locations*l.classes; + gradient_array(l.output + p_index, locations*l.n, LOGISTIC, l.delta + p_index); + } + + if (l.class_logistic) { + int class_index = index; + gradient_array(l.output + class_index, locations*l.classes, LOGISTIC, l.delta + class_index); + } + + if (l.coord_logistic) { + int coord_index = index + locations*(l.classes + l.n); + gradient_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC, l.delta + coord_index); + } //printf("\n"); } printf("Region Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count); diff --git a/src/swag.c b/src/swag.c index ec58f0d..8c9ce3c 100644 --- a/src/swag.c +++ b/src/swag.c @@ -73,6 +73,7 @@ void train_swag(char *cfgfile, char *weightfile) int side = l.side; int classes = l.classes; + float jitter = l.jitter; list *plist = get_paths(train_images); //int N = plist->size; @@ -85,6 +86,7 @@ void train_swag(char *cfgfile, char *weightfile) args.n = imgs; args.m = plist->size; args.classes = classes; + args.jitter = jitter; args.num_boxes = side; args.d = &buffer; args.type = REGION_DATA; @@ -127,7 +129,7 @@ void train_swag(char *cfgfile, char *weightfile) save_weights(net, buff); } -void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes) +void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness) { int i,j,n; //int per_cell = 5*num+classes; @@ -148,6 +150,9 @@ void convert_swag_detections(float *predictions, int classes, int num, int squar float prob = scale*predictions[class_index+j]; probs[index][j] = (prob > thresh) ? prob : 0; } + if(only_objectness){ + probs[index][0] = scale; + } } } } @@ -250,7 +255,7 @@ void validate_swag(char *cfgfile, char *weightfile) float *predictions = network_predict(net, X); int w = val[t].w; int h = val[t].h; - convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes); + convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0); if (nms) do_nms(boxes, probs, side*side*l.n, classes, iou_thresh); print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h); free(id); @@ -261,6 +266,95 @@ void validate_swag(char *cfgfile, char *weightfile) fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); } +void validate_swag_recall(char *cfgfile, char *weightfile) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + srand(time(0)); + + char *base = "results/comp4_det_test_"; + list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt"); + char **paths = (char **)list_to_array(plist); + + layer l = net.layers[net.n-1]; + int classes = l.classes; + int square = l.sqrt; + int side = l.side; + + int j, k; + FILE **fps = calloc(classes, sizeof(FILE *)); + for(j = 0; j < classes; ++j){ + char buff[1024]; + snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]); + fps[j] = fopen(buff, "w"); + } + box *boxes = calloc(side*side*l.n, sizeof(box)); + float **probs = calloc(side*side*l.n, sizeof(float *)); + for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *)); + + int m = plist->size; + int i=0; + + float thresh = .001; + int nms = 0; + float iou_thresh = .5; + float nms_thresh = .5; + + int total = 0; + int correct = 0; + int proposals = 0; + float avg_iou = 0; + + for(i = 0; i < m; ++i){ + char *path = paths[i]; + image orig = load_image_color(path, 0, 0); + image sized = resize_image(orig, net.w, net.h); + char *id = basecfg(path); + float *predictions = network_predict(net, sized.data); + int w = orig.w; + int h = orig.h; + convert_swag_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1); + if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh); + + char *labelpath = find_replace(path, "images", "labels"); + labelpath = find_replace(labelpath, "JPEGImages", "labels"); + labelpath = find_replace(labelpath, ".jpg", ".txt"); + labelpath = find_replace(labelpath, ".JPEG", ".txt"); + + int num_labels = 0; + box_label *truth = read_boxes(labelpath, &num_labels); + for(k = 0; k < side*side*l.n; ++k){ + if(probs[k][0] > thresh){ + ++proposals; + } + } + for (j = 0; j < num_labels; ++j) { + ++total; + box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h}; + float best_iou = 0; + for(k = 0; k < side*side*l.n; ++k){ + float iou = box_iou(boxes[k], t); + if(probs[k][0] > thresh && iou > best_iou){ + best_iou = iou; + } + } + avg_iou += best_iou; + if(best_iou > iou_thresh){ + ++correct; + } + } + + fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total); + free(id); + free_image(orig); + free_image(sized); + } +} + void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh) { @@ -316,4 +410,5 @@ void run_swag(int argc, char **argv) if(0==strcmp(argv[2], "test")) test_swag(cfg, weights, filename, thresh); else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights); else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights); + else if(0==strcmp(argv[2], "recall")) validate_swag_recall(cfg, weights); } -- GitLab