Skip to content
Snippets Groups Projects
Commit 59e35673 authored by Joseph Redmon's avatar Joseph Redmon
Browse files

writing stuff

parent fed6d6e3
No related branches found
No related tags found
No related merge requests found
[net]
batch=64
subdivisions=1
height=256
width=256
channels=3
learning_rate=0.00001
momentum=0.9
decay=0.0005
seen=0
[crop]
crop_height=256
crop_width=256
flip=0
angle=0
saturation=1
exposure=1
[convolutional]
filters=32
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=32
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=32
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=1
size=5
stride=1
pad=1
activation=logistic
[cost]
......@@ -54,7 +54,12 @@ matrix load_image_paths_gray(char **paths, int n, int w, int h)
X.cols = 0;
for(i = 0; i < n; ++i){
image im = load_image(paths[i], w, h, 1);
image im = load_image(paths[i], w, h, 3);
image gray = grayscale_image(im);
free_image(im);
im = gray;
X.vals[i] = im.data;
X.cols = im.h*im.w*im.c;
}
......@@ -571,14 +576,14 @@ pthread_t load_data_in_thread(load_args args)
return thread;
}
data load_data_writing(char **paths, int n, int m, int w, int h)
data load_data_writing(char **paths, int n, int m, int w, int h, int downsample)
{
if(m) paths = get_random_paths(paths, n, m);
char **replace_paths = find_replace_paths(paths, n, ".png", "-label.png");
data d;
d.shallow = 0;
d.X = load_image_paths(paths, n, w, h);
d.y = load_image_paths_gray(replace_paths, n, w/8, h/8);
d.y = load_image_paths_gray(replace_paths, n, w/downsample, h/downsample);
if(m) free(paths);
int i;
for(i = 0; i < n; ++i) free(replace_paths[i]);
......
......@@ -68,7 +68,7 @@ box_label *read_boxes(char *filename, int *n);
data load_cifar10_data(char *filename);
data load_all_cifar10();
data load_data_writing(char **paths, int n, int m, int w, int h);
data load_data_writing(char **paths, int n, int m, int w, int h, int downsample);
list *get_paths(char *filename);
char **get_labels(char *filename);
......
This diff is collapsed.
......@@ -61,6 +61,7 @@ void forward_region_layer(const region_layer l, network_state state)
if(state.train){
float avg_iou = 0;
float avg_cat = 0;
float avg_allcat = 0;
float avg_obj = 0;
float avg_anyobj = 0;
int count = 0;
......@@ -90,6 +91,7 @@ void forward_region_layer(const region_layer l, network_state state)
l.delta[class_index+j] = l.class_scale * (state.truth[truth_index+1+j] - l.output[class_index+j]);
*(l.cost) += l.class_scale * pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2);
if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j];
avg_allcat += l.output[class_index+j];
}
box truth = float_to_box(state.truth + truth_index + 1 + l.classes);
......@@ -151,7 +153,7 @@ void forward_region_layer(const region_layer l, network_state state)
LOGISTIC, l.delta + index + locations*l.classes);
}
}
printf("Region Avg IOU: %f, Avg Cat Pred: %f, Avg Obj: %f, Avg Any: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
printf("Region Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
}
}
......
......@@ -132,21 +132,22 @@ void train_swag(char *cfgfile, char *weightfile)
void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes)
{
int i,j,n;
int per_cell = 5*num+classes;
//int per_cell = 5*num+classes;
for (i = 0; i < side*side; ++i){
int row = i / side;
int col = i % side;
for(n = 0; n < num; ++n){
int offset = i*per_cell + 5*n;
float scale = predictions[offset];
int index = i*num + n;
boxes[index].x = (predictions[offset + 1] + col) / side * w;
boxes[index].y = (predictions[offset + 2] + row) / side * h;
boxes[index].w = pow(predictions[offset + 3], (square?2:1)) * w;
boxes[index].h = pow(predictions[offset + 4], (square?2:1)) * h;
int p_index = side*side*classes + i*num + n;
float scale = predictions[p_index];
int box_index = side*side*(classes + num) + (i*num + n)*4;
boxes[index].x = (predictions[box_index + 0] + col) / side * w;
boxes[index].y = (predictions[box_index + 1] + row) / side * h;
boxes[index].w = pow(predictions[box_index + 2], (square?2:1)) * w;
boxes[index].h = pow(predictions[box_index + 3], (square?2:1)) * h;
for(j = 0; j < classes; ++j){
offset = i*per_cell + 5*num;
float prob = scale*predictions[offset+j];
int class_index = i*classes;
float prob = scale*predictions[class_index+j];
probs[index][j] = (prob > thresh) ? prob : 0;
}
}
......
......@@ -2,8 +2,13 @@
#include "utils.h"
#include "parser.h"
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#endif
void train_writing(char *cfgfile, char *weightfile)
{
char *backup_directory = "/home/pjreddie/backup/";
data_seed = time(0);
srand(time(0));
float avg_loss = -1;
......@@ -23,41 +28,78 @@ void train_writing(char *cfgfile, char *weightfile)
while(1){
++i;
time=clock();
data train = load_data_writing(paths, imgs, plist->size, 512, 512);
data train = load_data_writing(paths, imgs, plist->size, 256, 256, 1);
printf("Loaded %lf seconds\n",sec(clock()-time));
time=clock();
float loss = train_network(net, train);
#ifdef GPU
float *out = get_network_output_gpu(net);
#else
float *out = get_network_output(net);
#endif
image pred = float_to_image(64, 64, 1, out);
print_image(pred);
/*
image im = float_to_image(256, 256, 3, train.X.vals[0]);
image lab = float_to_image(64, 64, 1, train.y.vals[0]);
/*
image pred = float_to_image(64, 64, 1, out);
show_image(im, "image");
show_image(lab, "label");
print_image(lab);
show_image(pred, "pred");
cvWaitKey(0);
print_image(pred);
*/
/*
image im = float_to_image(256, 256, 3, train.X.vals[0]);
image lab = float_to_image(64, 64, 1, train.y.vals[0]);
image pred = float_to_image(64, 64, 1, out);
show_image(im, "image");
show_image(lab, "label");
print_image(lab);
show_image(pred, "pred");
cvWaitKey(0);
*/
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
free_data(train);
if((i % 20000) == 0) net.learning_rate *= .1;
//if(i%100 == 0 && net.learning_rate > .00001) net.learning_rate *= .97;
if(i%1000==0){
char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
save_weights(net, buff);
}
}
}
void test_writing(char *cfgfile, char *weightfile, char *outfile)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
srand(2222222);
clock_t time;
char filename[256];
fgets(filename, 256, stdin);
strtok(filename, "\n");
image im = load_image_color(filename, 0, 0);
//image im = load_image_color("/home/pjreddie/darknet/data/figs/C02-1001-Figure-1.png", 0, 0);
image sized = resize_image(im, net.w, net.h);
printf("%d %d %d\n", im.h, im.w, im.c);
float *X = sized.data;
time=clock();
network_predict(net, X);
printf("%s: Predicted in %f seconds.\n", filename, sec(clock()-time));
image pred = get_network_image(net);
if (outfile) {
printf("Save image as %s.png (shape: %d %d)\n", outfile, pred.w, pred.h);
save_image(pred, outfile);
} else {
show_image(pred, "prediction");
#ifdef OPENCV
cvWaitKey(0);
cvDestroyAllWindows();
#endif
}
free_image(im);
free_image(sized);
}
void run_writing(int argc, char **argv)
{
if(argc < 4){
......@@ -67,6 +109,8 @@ void run_writing(int argc, char **argv)
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
char *outfile = (argc > 5) ? argv[5] : 0;
if(0==strcmp(argv[2], "train")) train_writing(cfg, weights);
else if(0==strcmp(argv[2], "test")) test_writing(cfg, weights, outfile);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment