From 1a4ad4f36387ef67f12f2a7341532b2eb6bb6c33 Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Sun, 25 Feb 2024 16:18:21 -0500 Subject: [PATCH 1/7] add weighted sampling to NestedPool partially addresses #8 --- .../tasks/fusion/datamodules/data_utils.py | 232 +++++++++++++----- 1 file changed, 167 insertions(+), 65 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index a9a2f7878..f68ba54d6 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -329,11 +329,11 @@ def _boxes_snap_to_edges(given_box, snap_target): return adjusted_box -class NestedPool(list): +class NestedPool(): """ - Manages a sampling from a tree of indexes (represented as nested lists). + Manages a sampling from a tree of indexes (nodes are dictionaries, leafs are lists). - Helps with balancing samples over multiple criteria + Helps with balancing samples over multiple criteria. Example: >>> from geowatch.tasks.fusion.datamodules.data_utils import NestedPool @@ -342,22 +342,27 @@ class NestedPool(list): >>> # In this case region1 occurs more often than region2 and there is >>> # a rare category that only appears twice. >>> sample_grid = [ - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'rare'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region2', 'category': 'background'}, - >>> {'region': 'region2', 'category': 'background'}, - >>> {'region': 'region2', 'category': 'rare'}, + >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "purple" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "purple" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region1', 'category': 'rare', 'color': "red" }, + >>> { 'region': 'region1', 'category': 'rare', 'color': "green" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "purple" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region2', 'category': 'rare', 'color': "purple" }, + >>> { 'region': 'region2', 'category': 'rare', 'color': "blue" }, >>> ] >>> # >>> # First we can just create a flat uniform sampling grid - >>> # And inspect the imbalance that causes. + >>> # and inspect the imbalance that causes. 
>>> sample_idxs = list(range(len(sample_grid))) >>> self = NestedPool(sample_idxs) >>> print(f'self={self}') @@ -378,23 +383,72 @@ class NestedPool(list): >>> sampled = list(self._sample_many(100, sample_grid)) >>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1))) + >>> # + >>> # We can further subdivide by color, using custom weights. + >>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 } + >>> self.subdivide([g['color'] for g in sample_grid], weights=weights) + >>> print(f'self={self}') + >>> sampled = list(self._sample_many(100, sample_grid)) + >>> hist3 = ub.dict_hist([ + >>> (g['region'], g['category'], g['color']) for g in sampled + >>> ]) + >>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1))) + >>> hist3_color = ub.dict_hist([(g['color']) for g in sampled]) + >>> print('color weights = {}'.format(ub.urepr(weights, nl=1))) + >>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1))) Example: >>> from geowatch.tasks.fusion.datamodules.data_utils import * # NOQA >>> nested1 = NestedPool([[[1], [2, 3], [4, 5, 6], [7, 8, 9, 0]], [[11, 12, 13]]]) - >>> list(nested1.leafs()) - >>> print({nested1.sample() for i in range(100)}) >>> nested2 = NestedPool([[101], [102, 103], [104, 105, 106], [107, 8, 9, 0]]) >>> print({nested2.sample() for i in range(100)}) - >>> nested3 = NestedPool([nested1, nested2, [4, 59, 9, [], []]]) - >>> print({nested3.sample() for i in range(100)}) - >>> print(ub.urepr(ub.dict_hist(nested3.sample() for i in range(100)))) """ def __init__(self, pools, rng=None): - super().__init__(pools) self.rng = rng = kwarray.ensure_rng(rng) - self.pools = pools + self.pools = self._convert_to_weighted(self._validate(pools)) + + def _validate(self, _input): + # TODO: robustly validate the input to __init__ + if not isinstance(_input, list): + raise ValueError('NestedPool requires a list as input.') + if len(_input) == 0: + raise ValueError('NestedPool received an empty list as input.') + + def remove_empty_leafs(nested): + if not isinstance(nested, list): + return nested + return list(filter(lambda x: x != [], (map(remove_empty_leafs, nested)))) + return remove_empty_leafs(_input) + + def _compute_depth(self, x): + return isinstance(x, list) and max(map(self._compute_depth, x)) + 1 + + def _make_node(self, x): + return {"weights": None, "children": x} + + def _convert_to_weighted(self, nested): + """ + Convert from a tree (as a nested list of leaf values) to a representation + where nodes are dictionaries and children are lists. This allows specifying a + weight at every node to use when sampling. + + Note: A single level is still just a flat list internally (a leaf). 
+ """ + if not isinstance(nested, list): + return nested + + max_depth = self._compute_depth(nested) + if max_depth == 1: + return nested + + if max_depth == 2 and len(nested) >= 2: + return self._make_node(nested) + else: + collect = [] + for o in nested: + collect.append(self._convert_to_weighted(o)) + return self._make_node(collect) def _sample_many(self, num, items): for _ in range(num): @@ -402,7 +456,62 @@ class NestedPool(list): item = items[idx] yield item - def subdivide(self, items, key=None): + def sample(self): + chosen = self.pools + while ub.iterable(chosen): + if isinstance(chosen, dict): + # processing a node, sample using weights + weights = chosen["weights"] + children = chosen["children"] + num = len(children) + if weights is None: + idx = self.rng.randint(0, num) + else: + idx = self.rng.choice(num, 1, p=weights)[0] + chosen = children[idx] + elif isinstance(chosen, list): + # processing a leaf, sample uniformly + num = len(chosen) + idx = self.rng.randint(0, num) + chosen = chosen[idx] + return chosen + + def _subdivide_leaf(self, leaf, items, key=None, weights=None): + assert isinstance(leaf, list) + if len(leaf) == 1: + return leaf + + if key is not None: + groupids = list(map(key, ub.take(items, leaf))) + else: + groupids = list(ub.take(items, leaf)) + + groups = ub.group_items(leaf, groupids) + group_keys = groups.keys() + group_values = list(groups.values()) + + if len(group_values) > 1: + if weights is not None: + group_weights = np.asarray(list(ub.take(weights, group_keys))) + weights = group_weights / group_weights.sum() + return {"weights": weights, "children": group_values} + + def _subdivide_dict(self, nested, items, key=None, weights=None): + if not isinstance(nested, (list, dict)): + return nested + if isinstance(nested, dict): + if isinstance(nested["children"][0], list): + # children are leafs + nested["children"] = self._subdivide_dict(nested["children"], items, key=key, weights=weights) + return nested + else: + # children are nodes + nested["children"] = [self._subdivide_dict(x, items, key=key, weights=weights) for x in nested["children"]] + return nested + else: + return [self._subdivide_leaf(o, items, key=key, weights=weights) for o in nested] + + def subdivide(self, items, key=None, weights=None): """ Args: items (List): @@ -413,50 +522,43 @@ class NestedPool(list): key (None | Callable): if specified, for each ``items[i]`` found transform it into the group-id based on ``key(items[i])``. - """ - for leaf in self.leafs(): - if key is not None: - groupids = list(map(key, ub.take(items, leaf))) - else: - groupids = list(ub.take(items, leaf)) - new_subleafs = list(ub.group_items(leaf, groupids).values()) - if len(new_subleafs) > 1: - # Clear the current leaf and replace it with new subleafs - leaf[:] = new_subleafs - def leafs(self): + weights (None | Dict): + a dictionary of weights for possible categories in items. """ - Iterate over the deepest index lists in this pool. 
- """ - stack = [self] - while stack: - curr = stack.pop() - assert ub.iterable(curr) - if len(curr) == 0 or not ub.iterable(curr[0]): - # Found a leaf - yield curr - else: - for child in curr: - stack.append(child) + if isinstance(self.pools, list): + self.pools = self._subdivide_leaf(self.pools, + items, + key=key, + weights=weights) + else: + self.pools = self._subdivide_dict(self.pools, + items, + key=key, + weights=weights) + + def __len__(self): + if isinstance(self.pools, list): + return len(self.pools) + else: + def nested_len(nested): + return sum(nested_len(x) if isinstance(x, list) else 1 for x in nested) + pool_list = self._traverse(self.pools) + return nested_len(pool_list) + + def __str__(self): + if isinstance(self.pools, list): + return str(self.pools) + else: + return str(self._traverse(self.pools)) - def sample(self): - # Hack for empty lists - chosen = self - i = 0 - while ub.iterable(chosen): - chosen = self - i += 1 - while ub.iterable(chosen): - i += 1 - num = len(chosen) - # Fixme: not robust - if i > 100000: - raise Exception('Too many samples. Bad balance?') - if not num: - break - idx = self.rng.randint(0, num) - chosen = chosen[idx] - return chosen + def _traverse(self, nested): + if not isinstance(nested, (list, dict)): + return nested + if isinstance(nested, dict): + return self._traverse(nested["children"]) + else: + return [self._traverse(o) for o in nested] def samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0, -- GitLab From 21d66b340cee5eaee92834f964aa175e3f32f429 Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Mon, 26 Feb 2024 20:46:25 -0500 Subject: [PATCH 2/7] use networkx to manage the tree --- .../tasks/fusion/datamodules/data_utils.py | 271 +++++++----------- 1 file changed, 103 insertions(+), 168 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index f68ba54d6..d29ee4f41 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -6,6 +6,7 @@ import numpy as np import ubelt as ub import kwimage import kwarray +import networkx as nx def resolve_scale_request(request=None, data_gsd=None): @@ -331,16 +332,13 @@ def _boxes_snap_to_edges(given_box, snap_target): class NestedPool(): """ - Manages a sampling from a tree of indexes (nodes are dictionaries, leafs are lists). - - Helps with balancing samples over multiple criteria. + Manages a sampling from a tree of indexes. Helps with balancing + samples over multiple criteria. Example: >>> from geowatch.tasks.fusion.datamodules.data_utils import NestedPool - >>> # Lets say that you have a grid of sample locations with information - >>> # about them - say a source region and what category they contain. - >>> # In this case region1 occurs more often than region2 and there is - >>> # a rare category that only appears twice. + >>> # Given a grid of sample locations and attribute information + >>> # (e.g., region, category). 
>>> sample_grid = [ >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, >>> { 'region': 'region1', 'category': 'background', 'color': "purple" }, @@ -351,8 +349,8 @@ class NestedPool(): >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, >>> { 'region': 'region1', 'category': 'rare', 'color': "red" }, >>> { 'region': 'region1', 'category': 'rare', 'color': "green" }, - >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, - >>> { 'region': 'region2', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "green" }, >>> { 'region': 'region2', 'category': 'background', 'color': "blue" }, >>> { 'region': 'region2', 'category': 'background', 'color': "purple" }, >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, @@ -363,202 +361,139 @@ class NestedPool(): >>> # >>> # First we can just create a flat uniform sampling grid >>> # and inspect the imbalance that causes. - >>> sample_idxs = list(range(len(sample_grid))) - >>> self = NestedPool(sample_idxs) + >>> self = NestedPool(sample_grid) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) + >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1))) >>> # >>> # We can subdivide the indexes based on region to improve balance. - >>> self.subdivide([g['region'] for g in sample_grid]) + >>> self.subdivide('region') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) + >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist1 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1))) >>> # >>> # We can further subdivide by category. - >>> self.subdivide([g['category'] for g in sample_grid]) + >>> self.subdivide('category') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) + >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1))) >>> # - >>> # We can further subdivide by color, using custom weights. + >>> # We can further subdivide by color, with custom weights. 
>>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 } - >>> self.subdivide([g['color'] for g in sample_grid], weights=weights) + >>> self.subdivide('color', weights=weights) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) - >>> hist3 = ub.dict_hist([ + >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> hist2 = ub.dict_hist([ >>> (g['region'], g['category'], g['color']) for g in sampled >>> ]) - >>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1))) - >>> hist3_color = ub.dict_hist([(g['color']) for g in sampled]) + >>> print('hist3 = {}'.format(ub.urepr(hist2, nl=1))) + >>> hist2 = ub.dict_hist([(g['color']) for g in sampled]) >>> print('color weights = {}'.format(ub.urepr(weights, nl=1))) - >>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1))) - - Example: - >>> from geowatch.tasks.fusion.datamodules.data_utils import * # NOQA - >>> nested1 = NestedPool([[[1], [2, 3], [4, 5, 6], [7, 8, 9, 0]], [[11, 12, 13]]]) - >>> print({nested1.sample() for i in range(100)}) - >>> nested2 = NestedPool([[101], [102, 103], [104, 105, 106], [107, 8, 9, 0]]) - >>> print({nested2.sample() for i in range(100)}) + >>> print('hist3 (color) = {}'.format(ub.urepr(hist2, nl=1))) """ - def __init__(self, pools, rng=None): + def __init__(self, sample_grid, rng=None): self.rng = rng = kwarray.ensure_rng(rng) - self.pools = self._convert_to_weighted(self._validate(pools)) - - def _validate(self, _input): - # TODO: robustly validate the input to __init__ - if not isinstance(_input, list): - raise ValueError('NestedPool requires a list as input.') - if len(_input) == 0: - raise ValueError('NestedPool received an empty list as input.') + self.graph = self._create_graph(sample_grid) - def remove_empty_leafs(nested): - if not isinstance(nested, list): - return nested - return list(filter(lambda x: x != [], (map(remove_empty_leafs, nested)))) - return remove_empty_leafs(_input) + def _create_graph(self, sample_grid): + graph = nx.DiGraph() - def _compute_depth(self, x): - return isinstance(x, list) and max(map(self._compute_depth, x)) + 1 + # make a special root node + root_node = '__root__' + graph.add_node(root_node, weights=None) - def _make_node(self, x): - return {"weights": None, "children": x} - - def _convert_to_weighted(self, nested): - """ - Convert from a tree (as a nested list of leaf values) to a representation - where nodes are dictionaries and children are lists. This allows specifying a - weight at every node to use when sampling. - - Note: A single level is still just a flat list internally (a leaf). - """ - if not isinstance(nested, list): - return nested - - max_depth = self._compute_depth(nested) - if max_depth == 1: - return nested - - if max_depth == 2 and len(nested) >= 2: - return self._make_node(nested) + for index, item in enumerate(sample_grid): + label = f'{index:02d} ' + ub.urepr(item, nl=0, compact=1, nobr=1) + if isinstance(item, dict): + graph.add_node(index, label=label, **item) + else: + graph.add_node(index, label=label) + graph.add_edge(root_node, index) + return graph + + def _get_leafs(self): + """ Return sink nodes for the graph """ + return (n for n in self.graph.nodes if self.graph.out_degree[n] == 0) + + def _get_parent(self, n): + """ Get the parent of a node (assume a tree). 
None if it doesnt exist """ + preds = self.graph.pred[n] + if len(preds): + assert len(preds) == 1 + return next(iter(preds)) else: - collect = [] - for o in nested: - collect.append(self._convert_to_weighted(o)) - return self._make_node(collect) + return None + + def subdivide(self, key, weights=None): + remove_edges = [] + add_edges = [] + add_nodes = [] + + # Group all leaf nodes by their direct parents + parent_to_leafs = ub.group_items(self._get_leafs(), key=lambda n: self._get_parent(n)) + for parent, children in parent_to_leafs.items(): + # Group children by the new attribute + val_to_subgroup = ub.group_items(children, lambda n: self.graph.nodes[n][key]) + if len(val_to_subgroup) == 1: + # Dont need to do anything if no splits were made + ... + else: + # Otherwise, we have to subdivide the children + for value, subgroup in val_to_subgroup.items(): + # Use a dotted name to make unambiguous tree splits + new_parent = f'{parent}.{key}={value}' + # Mark edges to add / remove to implement the split + remove_edges.extend([(parent, n) for n in subgroup]) + add_edges.extend([(parent, new_parent) for n in subgroup]) + add_edges.extend([(new_parent, n) for n in subgroup]) + add_nodes.append(new_parent) + + # Add weights to the prior parent + if weights is not None: + weights_group = np.asarray(list(ub.take(weights, val_to_subgroup.keys()))) + weights_group = weights_group / weights_group.sum() + self.graph.nodes[parent]['weights'] = weights_group + else: + self.graph.nodes[parent]["weights"] = None + + # Modify the graph + self.graph.remove_edges_from(remove_edges) + self.graph.add_nodes_from(add_nodes, weights=None) + self.graph.add_edges_from(add_edges) - def _sample_many(self, num, items): + def _sample_many(self, num, return_attributes=False): for _ in range(num): idx = self.sample() - item = items[idx] - yield item + if return_attributes: + node = dict(self.graph.nodes[idx]) + node.pop("label") + yield node + else: + yield idx def sample(self): - chosen = self.pools - while ub.iterable(chosen): - if isinstance(chosen, dict): - # processing a node, sample using weights - weights = chosen["weights"] - children = chosen["children"] - num = len(children) - if weights is None: - idx = self.rng.randint(0, num) - else: - idx = self.rng.choice(num, 1, p=weights)[0] - chosen = children[idx] - elif isinstance(chosen, list): - # processing a leaf, sample uniformly - num = len(chosen) - idx = self.rng.randint(0, num) - chosen = chosen[idx] - return chosen - - def _subdivide_leaf(self, leaf, items, key=None, weights=None): - assert isinstance(leaf, list) - if len(leaf) == 1: - return leaf + current = '__root__' + while self.graph.out_degree(current) > 0: + children = list(self.graph.neighbors(current)) + num = len(children) - if key is not None: - groupids = list(map(key, ub.take(items, leaf))) - else: - groupids = list(ub.take(items, leaf)) - - groups = ub.group_items(leaf, groupids) - group_keys = groups.keys() - group_values = list(groups.values()) - - if len(group_values) > 1: - if weights is not None: - group_weights = np.asarray(list(ub.take(weights, group_keys))) - weights = group_weights / group_weights.sum() - return {"weights": weights, "children": group_values} - - def _subdivide_dict(self, nested, items, key=None, weights=None): - if not isinstance(nested, (list, dict)): - return nested - if isinstance(nested, dict): - if isinstance(nested["children"][0], list): - # children are leafs - nested["children"] = self._subdivide_dict(nested["children"], items, key=key, weights=weights) - 
return nested + weights = self.graph.nodes[current]['weights'] + if weights is None: + idx = self.rng.randint(0, num) else: - # children are nodes - nested["children"] = [self._subdivide_dict(x, items, key=key, weights=weights) for x in nested["children"]] - return nested - else: - return [self._subdivide_leaf(o, items, key=key, weights=weights) for o in nested] + idx = self.rng.choice(num, 1, p=weights)[0] - def subdivide(self, items, key=None, weights=None): - """ - Args: - items (List): - a list of items that the indexes index into. - If these are not the attributes to split nodes on, then - key must be specified: - - key (None | Callable): - if specified, for each ``items[i]`` found transform it into - the group-id based on ``key(items[i])``. - - weights (None | Dict): - a dictionary of weights for possible categories in items. - """ - if isinstance(self.pools, list): - self.pools = self._subdivide_leaf(self.pools, - items, - key=key, - weights=weights) - else: - self.pools = self._subdivide_dict(self.pools, - items, - key=key, - weights=weights) + current = children[idx] + return current def __len__(self): - if isinstance(self.pools, list): - return len(self.pools) - else: - def nested_len(nested): - return sum(nested_len(x) if isinstance(x, list) else 1 for x in nested) - pool_list = self._traverse(self.pools) - return nested_len(pool_list) + return len(list(self._get_leafs())) def __str__(self): - if isinstance(self.pools, list): - return str(self.pools) - else: - return str(self._traverse(self.pools)) - - def _traverse(self, nested): - if not isinstance(nested, (list, dict)): - return nested - if isinstance(nested, dict): - return self._traverse(nested["children"]) - else: - return [self._traverse(o) for o in nested] + return "\n".join(x for x in nx.generate_network_text(self.graph)) def samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0, -- GitLab From a3559ea957340bd0cbe85611cf81190e76cbbedd Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 13:25:31 -0500 Subject: [PATCH 3/7] add input validation so tests properly fail --- geowatch/tasks/fusion/datamodules/data_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index d29ee4f41..d556a3b54 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -396,7 +396,14 @@ class NestedPool(): """ def __init__(self, sample_grid, rng=None): self.rng = rng = kwarray.ensure_rng(rng) - self.graph = self._create_graph(sample_grid) + + # validate input + if isinstance(sample_grid, list) and sample_grid: + if isinstance(sample_grid[0], (dict, int, float)): + self.graph = self._create_graph(sample_grid) + return + raise ValueError("""NestedPool only supports input in the + form of a flat list or list of dicts.""") def _create_graph(self, sample_grid): graph = nx.DiGraph() -- GitLab From 00d92cce0407598a0c5a0d62e09ee384eca8ee46 Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 14:52:09 -0500 Subject: [PATCH 4/7] propagate changes from NestedPool to balanced sampling --- .../fusion/datamodules/kwcoco_dataset.py | 114 ++++++++---------- 1 file changed, 53 insertions(+), 61 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py index c90aad709..7b104d05d 100644 --- 
a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py +++ b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py @@ -136,7 +136,6 @@ from typing import NamedTuple from geowatch import heuristics from geowatch.utils import kwcoco_extensions from geowatch.utils import util_bands -from geowatch.utils import util_iter from geowatch.utils import util_kwarray from geowatch.utils import util_kwimage from geowatch.tasks.fusion import utils @@ -2388,11 +2387,8 @@ class BalanceMixin: Helpers to build the sample grid and balance it """ - def _setup_balance_dataframe(self, new_sample_grid): - target_vidids = [v['video_id'] for v in new_sample_grid['targets']] - - # extract video names - unique_vidids, _idx_to_unique_idx = np.unique(target_vidids, return_inverse=True) + def _get_video_names(self, vidids): + unique_vidids, _idx_to_unique_idx = np.unique(vidids, return_inverse=True) coco_dset = self.sampler.dset try: unique_vidnames = self.sampler.dset.videos(unique_vidids).lookup('name') @@ -2406,12 +2402,9 @@ class BalanceMixin: vidname = video_id unique_vidnames.append(vidname) vidnames = list(ub.take(unique_vidnames, _idx_to_unique_idx)) + return vidnames - # associate targets with positive or negative - target_posbit = kwarray.boolmask( - new_sample_grid['positives_indexes'], - len(new_sample_grid['targets'])) - + def _get_region_names(self, vidnames): # create mapping from video name to region name from kwutil import util_pattern pat = util_pattern.Pattern.coerce(r'\w+_[A-Z]\d+_.*', 'regex') @@ -2421,13 +2414,43 @@ class BalanceMixin: self.vidname_to_region_name[vidname] = "_".join(vidname.split('_')[:2]) else: self.vidname_to_region_name[vidname] = vidname + return list(ub.take(self.vidname_to_region_name, vidnames)) + + def _get_observed_annotations(self, targets): + gid_to_category = ub.AutoDict() + for gid in self.sampler.dset.annots().gids: + cats = self.sampler.dset.annots(image_id=gid).category_names + gid_to_category[gid] = cats + + observed_annos = ub.AutoDict() + for idx, target in enumerate(targets): + observed_cats = gid_to_category[target["main_gid"]] + unique_cats = set(observed_cats) + observed_annos[idx] = unique_cats + return observed_annos + + def _setup_attribute_dataframe(self, new_sample_grid): + """ + Build a dataframe of attributes (for each sample) that can be used for balancing. + """ + video_ids = [v['video_id'] for v in new_sample_grid['targets']] + video_names = self._get_video_names(video_ids) + region_names = self._get_region_names(video_names) + observed_annos = self._get_observed_annotations(new_sample_grid['targets']) + observed_phases = ub.util_dict.map_vals(lambda x: set(heuristics.PHASES).intersection(x), observed_annos) + + # associate targets with positive or negative + contains_positive = ['positive' in v for (k, v) in observed_annos.items()] + contains_phase = [any(v) for (k, v) in observed_phases.items()] - # build a dataframe with video attributes + # build a dataframe with target attributes df = pd.DataFrame({ - 'vidid': target_vidids, - 'vidname': vidnames, - 'is_positive': target_posbit, - 'region': list(ub.take(self.vidname_to_region_name, vidnames)) + 'video_id': video_ids, + 'video_name': video_names, + 'region': region_names, + 'contains_positive': contains_positive, + 'contains_phase': contains_phase, + 'phases': observed_phases.values(), }).reset_index(drop=False) return df @@ -2436,55 +2459,24 @@ class BalanceMixin: Build data structure used for balanced sampling. 
Helper for __init__ which constructs a NestedPool to balance sampling - acrgeowatch/tasks/fusion/datamodules/kwcoco_dataset.pyoss input domains. - - TODO: - HELP WANTED: We would like to configure the distribution in some - easy to specify way. We should be domain aware, or rather accept - some encoding of the domain. We want to oversample underrepresented - or important batch items and undersample overrepresented or - unimportant easy batch items. The "batch item" part is what makes - this hard because we need the notation of goodness, easiness, etc - at the batch level, which can contain multiple annotations. + across input domains. """ - print('Balancing over regions') - df_videos = self._setup_balance_dataframe(new_sample_grid) - - # balance positive / negatives per region - neg_to_pos_ratio = self.config['neg_to_pos_ratio'] - region_to_pool = {} - for region in df_videos['region'].unique(): - df_pos_frames = df_videos.query(f"region == '{region}' and is_positive == True") - df_neg_frames = df_videos.query(f"region == '{region}' and is_positive == False") - pos_frames_idxs = df_pos_frames['index'] - neg_frames_idxs = df_neg_frames['index'] - - n_pos = len(df_pos_frames) - n_neg = len(df_neg_frames) - max_neg = min(int(max(1, (neg_to_pos_ratio * n_pos))), n_neg) - - neg_region_pool_ = list(util_iter.chunks(neg_frames_idxs, nchunks=max_neg)) - pos_region_pool_ = list(util_iter.chunks(pos_frames_idxs, nchunks=n_pos)) - region_pool = pos_region_pool_ + neg_region_pool_ - region_to_pool[region] = [p for p in region_pool if p] - - # compute maximum to take per region - freqs = list(map(len, region_to_pool.values())) - if len(freqs) == 0: - max_per_region = 100 - warnings.warn('Warning: no region pool') - else: - max_per_region = int(np.median(freqs)) + print('Balancing over attributes') + df_sample_attributes = self._setup_attribute_dataframe(new_sample_grid) + + # Initialize an instance of NestedPool + self.nested_pool = data_utils.NestedPool(df_sample_attributes.to_dict('records')) + + # Compute weights for subdivide + npr = self.config['neg_to_pos_ratio'] + npr_dist = np.asarray([1, npr]) / (1 + npr) + weights_pos = dict(zip([True, False], npr_dist)) - # balance across regions - all_chunks = [] - for region, region_pool in region_to_pool.items(): - rechunked_region_pool = list(util_iter.chunks(region_pool, nchunks=max_per_region)) - all_chunks.extend(rechunked_region_pool) + self.nested_pool.subdivide('region') + self.nested_pool.subdivide('contains_positive', weights=weights_pos) + self.nested_pool.subdivide('contains_phase') - # initialize nested pool - self.nested_pool = data_utils.NestedPool(all_chunks) if self.config['reseed_fit_random_generators']: self.reseed() -- GitLab From 0db3f7d8502d102fe2bf1b9e34369fa3fd5d3eaf Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 16:39:08 -0500 Subject: [PATCH 5/7] minor improvements --- .../tasks/fusion/datamodules/data_utils.py | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index d556a3b54..c4551ef36 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -330,7 +330,7 @@ def _boxes_snap_to_edges(given_box, snap_target): return adjusted_box -class NestedPool(): +class NestedPool(ub.NiceRepr): """ Manages a sampling from a tree of indexes. Helps with balancing samples over multiple criteria. 
@@ -395,12 +395,14 @@ class NestedPool(): >>> print('hist3 (color) = {}'.format(ub.urepr(hist2, nl=1))) """ def __init__(self, sample_grid, rng=None): + super().__init__() self.rng = rng = kwarray.ensure_rng(rng) # validate input if isinstance(sample_grid, list) and sample_grid: if isinstance(sample_grid[0], (dict, int, float)): self.graph = self._create_graph(sample_grid) + self._leaf_nodes = [n for n in self.graph.nodes if self.graph.out_degree[n] == 0] return raise ValueError("""NestedPool only supports input in the form of a flat list or list of dicts.""") @@ -421,10 +423,6 @@ class NestedPool(): graph.add_edge(root_node, index) return graph - def _get_leafs(self): - """ Return sink nodes for the graph """ - return (n for n in self.graph.nodes if self.graph.out_degree[n] == 0) - def _get_parent(self, n): """ Get the parent of a node (assume a tree). None if it doesnt exist """ preds = self.graph.pred[n] @@ -440,7 +438,7 @@ class NestedPool(): add_nodes = [] # Group all leaf nodes by their direct parents - parent_to_leafs = ub.group_items(self._get_leafs(), key=lambda n: self._get_parent(n)) + parent_to_leafs = ub.group_items(self._leaf_nodes, key=lambda n: self._get_parent(n)) for parent, children in parent_to_leafs.items(): # Group children by the new attribute val_to_subgroup = ub.group_items(children, lambda n: self.graph.nodes[n][key]) @@ -471,20 +469,15 @@ class NestedPool(): self.graph.add_nodes_from(add_nodes, weights=None) self.graph.add_edges_from(add_edges) - def _sample_many(self, num, return_attributes=False): + def _sample_many(self, num): for _ in range(num): idx = self.sample() - if return_attributes: - node = dict(self.graph.nodes[idx]) - node.pop("label") - yield node - else: - yield idx + yield idx def sample(self): current = '__root__' while self.graph.out_degree(current) > 0: - children = list(self.graph.neighbors(current)) + children = list(self.graph.successors(current)) num = len(children) weights = self.graph.nodes[current]['weights'] @@ -497,10 +490,14 @@ class NestedPool(): return current def __len__(self): - return len(list(self._get_leafs())) - - def __str__(self): - return "\n".join(x for x in nx.generate_network_text(self.graph)) + return len(list(self._leaf_nodes)) + + def __nice__(self): + n_nodes = self.graph.number_of_nodes() + n_edges = self.graph.number_of_edges() + n_leafs = self.__len__() + n_depth = len(nx.algorithms.dag.dag_longest_path(self.graph)) + return f'nodes={n_nodes}, edges={n_edges}, leafs={n_leafs}, depth={n_depth}' def samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0, -- GitLab From 9e4112323686ef5dd9fd77d9f4ac14bccbee9fbf Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 16:51:43 -0500 Subject: [PATCH 6/7] rename NestedPool to BalancedSampleTree --- geowatch/tasks/fusion/datamodules/data_utils.py | 8 ++++---- geowatch/tasks/fusion/datamodules/kwcoco_dataset.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index c4551ef36..aeabb1e89 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -330,13 +330,13 @@ def _boxes_snap_to_edges(given_box, snap_target): return adjusted_box -class NestedPool(ub.NiceRepr): +class BalancedSampleTree(ub.NiceRepr): """ Manages a sampling from a tree of indexes. Helps with balancing samples over multiple criteria. 
Example: - >>> from geowatch.tasks.fusion.datamodules.data_utils import NestedPool + >>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleTree >>> # Given a grid of sample locations and attribute information >>> # (e.g., region, category). >>> sample_grid = [ @@ -361,7 +361,7 @@ class NestedPool(ub.NiceRepr): >>> # >>> # First we can just create a flat uniform sampling grid >>> # and inspect the imbalance that causes. - >>> self = NestedPool(sample_grid) + >>> self = BalancedSampleTree(sample_grid) >>> print(f'self={self}') >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) @@ -404,7 +404,7 @@ class NestedPool(ub.NiceRepr): self.graph = self._create_graph(sample_grid) self._leaf_nodes = [n for n in self.graph.nodes if self.graph.out_degree[n] == 0] return - raise ValueError("""NestedPool only supports input in the + raise ValueError("""BalancedSampleTree only supports input in the form of a flat list or list of dicts.""") def _create_graph(self, sample_grid): diff --git a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py index 7b104d05d..4e1b44e74 100644 --- a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py +++ b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py @@ -2458,15 +2458,15 @@ class BalanceMixin: """ Build data structure used for balanced sampling. - Helper for __init__ which constructs a NestedPool to balance sampling + Helper for __init__ which constructs a BalancedSampleTree to balance sampling across input domains. """ print('Balancing over attributes') df_sample_attributes = self._setup_attribute_dataframe(new_sample_grid) - # Initialize an instance of NestedPool - self.nested_pool = data_utils.NestedPool(df_sample_attributes.to_dict('records')) + # Initialize an instance of BalancedSampleTree + self.nested_pool = data_utils.BalancedSampleTree(df_sample_attributes.to_dict('records')) # Compute weights for subdivide npr = self.config['neg_to_pos_ratio'] -- GitLab From 671e55fcf99b7603e2f63752033a961d3ae678dc Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 21:44:25 -0500 Subject: [PATCH 7/7] handle edge case, update test --- .../tasks/fusion/datamodules/data_utils.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index aeabb1e89..c9dff6055 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -363,21 +363,21 @@ class BalancedSampleTree(ub.NiceRepr): >>> # and inspect the imbalance that causes. >>> self = BalancedSampleTree(sample_grid) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) >>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1))) >>> # >>> # We can subdivide the indexes based on region to improve balance. >>> self.subdivide('region') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) >>> hist1 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1))) >>> # >>> # We can further subdivide by category. 
>>> self.subdivide('category') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) >>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1))) >>> # @@ -385,14 +385,14 @@ class BalancedSampleTree(ub.NiceRepr): >>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 } >>> self.subdivide('color', weights=weights) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) - >>> hist2 = ub.dict_hist([ + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) + >>> hist3 = ub.dict_hist([ >>> (g['region'], g['category'], g['color']) for g in sampled >>> ]) - >>> print('hist3 = {}'.format(ub.urepr(hist2, nl=1))) - >>> hist2 = ub.dict_hist([(g['color']) for g in sampled]) + >>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1))) + >>> hist3_color = ub.dict_hist([(g['color']) for g in sampled]) >>> print('color weights = {}'.format(ub.urepr(weights, nl=1))) - >>> print('hist3 (color) = {}'.format(ub.urepr(hist2, nl=1))) + >>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1))) """ def __init__(self, sample_grid, rng=None): super().__init__() @@ -459,7 +459,10 @@ class BalancedSampleTree(ub.NiceRepr): # Add weights to the prior parent if weights is not None: weights_group = np.asarray(list(ub.take(weights, val_to_subgroup.keys()))) - weights_group = weights_group / weights_group.sum() + denom = weights_group.sum() + if denom == 0: + raise NotImplementedError('Zero weighted branches are not handled yet.') + weights_group = weights_group / denom self.graph.nodes[parent]['weights'] = weights_group else: self.graph.nodes[parent]["weights"] = None -- GitLab
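A minimal usage sketch of the BalancedSampleTree that this patch series ends up with, adapted from the doctest added in PATCH 7/7. It assumes a geowatch checkout with all seven patches applied (plus ubelt, networkx, kwarray, and numpy installed); the sample_grid records and the color weights below are illustrative stand-ins, not real geowatch sampling targets, and _sample_many is the underscore-prefixed helper the doctests themselves use.

    import ubelt as ub
    from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleTree

    # Illustrative grid of sample attributes (one dict per candidate sample).
    sample_grid = [
        {'region': 'region1', 'category': 'background', 'color': 'blue'},
        {'region': 'region1', 'category': 'background', 'color': 'green'},
        {'region': 'region1', 'category': 'rare',       'color': 'red'},
        {'region': 'region2', 'category': 'background', 'color': 'green'},
        {'region': 'region2', 'category': 'background', 'color': 'purple'},
        {'region': 'region2', 'category': 'rare',       'color': 'blue'},
    ]

    # Build a flat (uniform) pool over the sample indexes; rng seeds kwarray.ensure_rng.
    tree = BalancedSampleTree(sample_grid, rng=0)

    # Each subdivide call adds one level to the tree, so sampling first
    # balances regions, then categories within each region.
    tree.subdivide('region')
    tree.subdivide('category')

    # Optional per-value weights for the new level; they are normalized per node.
    tree.subdivide('color', weights={'red': .25, 'blue': .25, 'green': .4, 'purple': .1})

    # sample() walks from the root to a leaf, which is an index into sample_grid.
    sampled = list(ub.take(sample_grid, tree._sample_many(100)))
    hist = ub.dict_hist((g['region'], g['category']) for g in sampled)
    print(ub.urepr(hist, nl=1))

This mirrors how PATCH 4/7 wires the class into BalanceMixin: build the attribute records (region, contains_positive, contains_phase), construct one BalancedSampleTree over them, then call subdivide once per attribute, passing weights only where a non-uniform split (such as the neg_to_pos_ratio) is wanted.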