Better balanced sampling in KWCocoVideoDataset
In the KWCocoVideoDataset, we construct a list of "targets", which represent the result of sliding a window across space and time in each input video. By default, both "positive" targets (windows that contain an annotation) and "negative" targets (windows that do not contain an annotation) are included in this list.
The flat list of all possible targets is initialized here: https://gitlab.kitware.com/computer-vision/geowatch/-/blob/64c93bcc13ac2c93a55c48224945251f67ac9af8/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py?page=4#L3159
Naively, the torch dataset would just iterate through this flat list, but in typical use cases that causes a problem: the numbers of "positive" and "negative" samples are highly imbalanced. Thus we build a `data_utils.NestedPool` data structure, which lets us sample randomly from the list while oversampling / undersampling to roughly balance the number of "positive" and "negative" examples given to the network at train time.
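To make the imbalance concrete, here is a minimal sketch (not the geowatch implementation) that compares a uniform draw over a hypothetical flat target list with a two-stage draw that first picks "positive" or "negative" uniformly and then picks a target within that group. The `is_positive` key and the 10 / 990 split are assumptions for illustration only.

```python
import random
from collections import Counter


def sample_uniform(targets, num, rng):
    """Naive approach: draw targets uniformly from the flat list."""
    return [rng.choice(targets) for _ in range(num)]


def sample_balanced(targets, num, rng):
    """Pick "positive" or "negative" uniformly, then a target within it.

    This oversamples the minority group and undersamples the majority group.
    """
    groups = {
        'positive': [t for t in targets if t['is_positive']],
        'negative': [t for t in targets if not t['is_positive']],
    }
    # Only groups that actually contain targets can be chosen.
    names = [name for name, members in groups.items() if members]
    return [rng.choice(groups[rng.choice(names)]) for _ in range(num)]


# Hypothetical data: 10 positive windows and 990 negative windows.
targets = [{'is_positive': i < 10} for i in range(1000)]
rng = random.Random(0)
for sampler in [sample_uniform, sample_balanced]:
    counts = Counter(t['is_positive'] for t in sampler(targets, 10_000, rng))
    print(sampler.__name__, 'positive fraction:', counts[True] / 10_000)
```

The uniform draw yields positives roughly 1% of the time, while the two-stage draw yields them roughly 50% of the time, at the cost of repeating the same few positive windows often.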
A few lines later, this data structure is initialized:
```python
if 1:
    self._init_balance(new_sample_grid)
```
This balancing logic has recently been reorganized into a mixin class so that it can be experimented with. The relevant code is here: https://gitlab.kitware.com/computer-vision/geowatch/-/blob/64c93bcc13ac2c93a55c48224945251f67ac9af8/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py?page=4#L2294
The code for the `NestedPool` itself is here: https://gitlab.kitware.com/computer-vision/geowatch/-/blob/64c93bcc13ac2c93a55c48224945251f67ac9af8/geowatch/tasks/fusion/datamodules/data_utils.py#L332
When `__getitem__` is called in train mode, the integer index is "coerced" to a target, which ultimately calls this line: https://gitlab.kitware.com/computer-vision/geowatch/-/blob/64c93bcc13ac2c93a55c48224945251f67ac9af8/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py?page=4#L1342 That line samples an index from the `NestedPool` and then looks up the corresponding target.
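In simplified form, the train-mode behavior described above looks roughly like the toy class below. `ToyBalancedDataset`, `_coerce_target`, and `pool_sampler` are hypothetical names chosen for this sketch; the real coercion logic lives at the link above and does considerably more.

```python
import random


class ToyBalancedDataset:
    """Hypothetical stand-in for train-mode indexing: the integer index passed
    to __getitem__ is ignored and a target is drawn from a balanced pool."""

    def __init__(self, targets, pool_sampler, seed=0):
        self.targets = targets            # flat list of target dicts
        self.pool_sampler = pool_sampler  # callable(rng) -> index into targets
        self.rng = random.Random(seed)
        self.mode = 'train'

    def __len__(self):
        return len(self.targets)

    def _coerce_target(self, index):
        if self.mode == 'train':
            # Ignore the requested index; sample one from the balanced pool.
            index = self.pool_sampler(self.rng)
        return self.targets[index]

    def __getitem__(self, index):
        target = self._coerce_target(index)
        # A real dataset would load and return imagery for this target.
        return {'target': target}
```

Because the index is discarded in train mode, `__len__` only controls how many draws make up an epoch, not which targets are seen.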
In fact, `data_utils.NestedPool` does a bit more than positive / negative balancing. It also attempts to balance over regions, but I haven't vetted this code in a while and I suspect it might not be working exactly correctly. Moreover, when there is more than one attribute you want to balance over, those attributes may be in conflict. My idea for working around this is to build a tree in which each level represents a different attribute you might care to balance over. The ordering of these attributes will matter, but I have an idea for an extension that might mitigate this (more on this later).
Given a single tree, you randomly descend the branches until you reach a leaf, which contains the index of the chosen target.
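The descent can be sketched with plain nested Python lists standing in for the tree (an assumption; geowatch's `NestedPool` has its own representation). A child is chosen uniformly at each level, so a leaf's probability is the product of 1 / (number of siblings) along its path.

```python
import random
from collections import Counter


def sample_leaf(tree, rng):
    """Descend a nested-list tree, choosing a child uniformly at each level,
    until reaching a leaf (an integer index into the flat target list)."""
    node = tree
    while isinstance(node, list):
        node = rng.choice(node)
    return node


# Level 1 splits by region, level 2 splits by category (hypothetical data).
# region1: background -> [0, 2, 3], rare -> [1]
# region2: background -> [4, 5],    rare -> [6]
tree = [
    [[0, 2, 3], [1]],
    [[4, 5], [6]],
]
rng = random.Random(0)
counts = Counter(sample_leaf(tree, rng) for _ in range(8000))
```

Here index 1 (the only region1 "rare" target) has probability 1/2 * 1/2 = 1/4 even though it is one of seven leaves, while each region1 background index has probability 1/2 * 1/2 * 1/3 = 1/12.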
I have recently done some work on the `NestedPool` class that introduces a `subdivide` method, which I think is a more intuitive and understandable way to build the sampling tree. The doctest gives a reasonable example of this:
```python
import ubelt as ub
from geowatch.tasks.fusion.datamodules.data_utils import NestedPool
# Let's say that you have a grid of sample locations with information
# about them, e.g. a source region and what category they contain.
# In this case region1 occurs more often than region2 and there is
# a rare category that only appears twice.
sample_grid = [
    {'region': 'region1', 'category': 'background'},
    {'region': 'region1', 'category': 'rare'},
    {'region': 'region1', 'category': 'background'},
    {'region': 'region1', 'category': 'background'},
    {'region': 'region1', 'category': 'background'},
    {'region': 'region1', 'category': 'background'},
    {'region': 'region1', 'category': 'background'},
    {'region': 'region1', 'category': 'background'},
    {'region': 'region1', 'category': 'background'},
    {'region': 'region2', 'category': 'background'},
    {'region': 'region2', 'category': 'background'},
    {'region': 'region2', 'category': 'rare'},
]
#
# First we can just create a flat uniform sampling grid
# and inspect the imbalance that causes.
sample_idxs = list(range(len(sample_grid)))
self = NestedPool(sample_idxs)
print(f'self={self}')
sampled = list(self._sample_many(100, sample_grid))
hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
print('hist0 = {}'.format(ub.urepr(hist0, nl=1)))
#
# We can subdivide the indexes based on region to improve balance.
self.subdivide([g['region'] for g in sample_grid])
print(f'self={self}')
sampled = list(self._sample_many(100, sample_grid))
hist1 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
print('hist1 = {}'.format(ub.urepr(hist1, nl=1)))
#
# We can further subdivide by category.
self.subdivide([g['category'] for g in sample_grid])
print(f'self={self}')
sampled = list(self._sample_many(100, sample_grid))
hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
print('hist2 = {}'.format(ub.urepr(hist2, nl=1)))
```
What I would like is for the user to be able to specify an ordered list of "attributes" to balance over. As long as these attributes are stored in the target dictionary, it should be possible to initialize a flat `NestedPool` (i.e. a single-level tree whose leaves are the individual target indices) and then call `subdivide` on each of the requested attributes in order. This should implement the user-requested balancing scheme.
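A minimal sketch of that configuration-driven construction follows. `ToyNestedPool` is a hypothetical, simplified stand-in (not geowatch's `NestedPool`, whose `subdivide` takes a list of per-index values as in the doctest above), and `build_balanced_pool` / `balance_attrs` are illustrative names for what a config option might feed in.

```python
import random
from collections import Counter


class ToyNestedPool:
    """Hypothetical, simplified analogue of NestedPool (not geowatch's class).

    The tree is stored as nested lists whose innermost lists hold target
    indices. ``subdivide`` splits every innermost list by an attribute value.
    """

    def __init__(self, indices):
        # A flat pool: a single group holding every target index.
        self.root = list(indices)

    def subdivide(self, key_for_index):
        self.root = self._split(self.root, key_for_index)

    def _split(self, node, key_for_index):
        if node and isinstance(node[0], list):
            # Internal node: recurse into each child.
            return [self._split(child, key_for_index) for child in node]
        # Innermost group of indices: replace it with one group per key.
        groups = {}
        for idx in node:
            groups.setdefault(key_for_index(idx), []).append(idx)
        return list(groups.values())

    def sample(self, rng):
        # Choose a child uniformly at each level until reaching an index.
        node = self.root
        while isinstance(node, list):
            node = rng.choice(node)
        return node


def build_balanced_pool(targets, balance_attrs):
    """Build a pool from an ordered list of attribute names (e.g. from config).

    Each attribute adds one level to the tree, in the order given.
    """
    pool = ToyNestedPool(range(len(targets)))
    for attr in balance_attrs:
        pool.subdivide(lambda idx, attr=attr: targets[idx][attr])
    return pool


# Hypothetical targets with the same counts as the doctest's sample grid.
targets = (
    [{'region': 'region1', 'category': 'background'} for _ in range(8)]
    + [{'region': 'region1', 'category': 'rare'}]
    + [{'region': 'region2', 'category': 'background'} for _ in range(2)]
    + [{'region': 'region2', 'category': 'rare'}]
)
pool = build_balanced_pool(targets, ['region', 'category'])
rng = random.Random(0)
hist = Counter(
    (targets[i]['region'], targets[i]['category'])
    for i in (pool.sample(rng) for _ in range(10_000))
)
```

With region first and category second, each of the four (region, category) cells is drawn about 25% of the time, even though the region2 "rare" cell holds a single target.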
The thing I'm most interested in balancing over is whether annotations have phase labels. An attribute representing this may need to be added to the target dictionary, but once that is done, calling `subdivide` on it should balance over it. The problem I'm seeing is that, when learning class labels, the network is incentivized not to care much about phase, because the overwhelming majority of positive cases have no phase labels.
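One way such an attribute could be derived is sketched below. It assumes a hypothetical `annots_by_target` mapping from each target's position to the annotation dicts inside its window, and a hypothetical `phase` key on each annotation; neither is confirmed to match geowatch's actual data layout. Using three values (rather than a boolean) keeps negatives separate from positives with and without phase labels.

```python
def add_phase_status(targets, annots_by_target):
    """Attach a hypothetical ``phase_status`` attribute to each target dict.

    ``annots_by_target`` maps a target's position in ``targets`` to the list
    of annotation dicts that fall inside its window.
    """
    for idx, target in enumerate(targets):
        annots = annots_by_target.get(idx, [])
        if not annots:
            status = 'negative'
        elif any(annot.get('phase') for annot in annots):
            # At least one annotation carries a non-empty phase label.
            status = 'positive_with_phase'
        else:
            status = 'positive_no_phase'
        target['phase_status'] = status
    return targets
```

Once every target carries `phase_status`, it can be passed to `subdivide` like any other attribute.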
To demonstrate that this works, it would be interesting to see a histogram of the attributes that are actually sampled when calling `__getitem__` multiple times in train mode.
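Such a histogram could be collected with something like the sketch below. `get_target` is a hypothetical hook: it assumes you can recover the target dict chosen on each train-mode call (for example, through a thin wrapper around `__getitem__`); this sketch does not confirm that the real item exposes it that way.

```python
from collections import Counter


def sampled_attribute_histogram(get_target, num_samples, attrs):
    """Call ``get_target(i)`` repeatedly and tally the requested attributes.

    Returns the fraction of draws for each combination of attribute values,
    so runs with different ``num_samples`` are directly comparable.
    """
    hist = Counter()
    for i in range(num_samples):
        target = get_target(i)
        hist[tuple(target[attr] for attr in attrs)] += 1
    # most_common() avoids sorting keys whose values may not be comparable.
    return {key: count / num_samples for key, count in hist.most_common()}
```

Comparing this output for a flat pool against a subdivided pool gives the before / after view of the balancing.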
As an extension to mitigate the fact that the order of the attribute splits matters, I'm wondering what the effect of a random-forest-like sampling approach would be. In other words, start with multiple flat `NestedPool` objects, and for each of them call `subdivide` using a random ordering of the user-requested attributes. The top level of a nested pool is always perfectly balanced, so at least one pool would be perfectly balanced. Then randomly sampling a pool, and then sampling from that pool, might result in a "best-of-all-worlds" balancing scheme. I'm not sure exactly how the math works out in this case, so it would be interesting to see the attribute comparison histogram for this approach as well.
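A minimal sketch of the forest idea, again using plain nested lists rather than geowatch's `NestedPool`. `ToyPoolForest`, `build_tree`, and `num_trees` are hypothetical names. On the math: because a tree is picked uniformly first, each target's overall probability is the average of its per-tree probabilities, P(i) = (1/K) * sum over k of P_k(i), where K is the number of trees.

```python
import random


def build_tree(indices, key_fns):
    """Recursively split ``indices`` by each key function, in order."""
    if not key_fns:
        return list(indices)
    groups = {}
    for idx in indices:
        groups.setdefault(key_fns[0](idx), []).append(idx)
    return [build_tree(group, key_fns[1:]) for group in groups.values()]


def sample_leaf(tree, rng):
    """Choose a child uniformly at each level until reaching an index."""
    node = tree
    while isinstance(node, list):
        node = rng.choice(node)
    return node


class ToyPoolForest:
    """Hypothetical forest of sampling trees, each splitting on its own
    random permutation of the requested attributes."""

    def __init__(self, targets, attrs, num_trees, seed=0):
        self.rng = random.Random(seed)
        self.trees = []
        for _ in range(num_trees):
            order = self.rng.sample(list(attrs), len(attrs))
            key_fns = [lambda i, a=a: targets[i][a] for a in order]
            self.trees.append(build_tree(range(len(targets)), key_fns))

    def sample(self):
        # Pick a tree uniformly, then descend it.
        return sample_leaf(self.rng.choice(self.trees), self.rng)


# Hypothetical grid where region2 has no "rare" targets, so order matters:
# region-first trees draw "rare" 25% of the time, category-first trees 50%.
targets = (
    [{'region': 'region1', 'category': 'background'} for _ in range(8)]
    + [{'region': 'region1', 'category': 'rare'}]
    + [{'region': 'region2', 'category': 'background'} for _ in range(3)]
)
forest = ToyPoolForest(targets, ['region', 'category'], num_trees=8)
picks = [targets[forest.sample()] for _ in range(20_000)]
frac_rare = sum(p['category'] == 'rare' for p in picks) / len(picks)
```

By the averaging argument above, the forest's "rare" rate lands between the 25% and 50% produced by the two possible orders, with the exact value depending on how many trees drew each ordering.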
The important part of this feature is that it is agnostic to the problem (heavy construction, etc.) and gives the user a configuration-level mechanism to control this important aspect of training without needing to hack datasets or modify code.