From 1a4ad4f36387ef67f12f2a7341532b2eb6bb6c33 Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Sun, 25 Feb 2024 16:18:21 -0500 Subject: [PATCH 1/7] add weighted sampling to NestedPool partially addresses #8 --- .../tasks/fusion/datamodules/data_utils.py | 232 +++++++++++++----- 1 file changed, 167 insertions(+), 65 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index a9a2f7878..f68ba54d6 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -329,11 +329,11 @@ def _boxes_snap_to_edges(given_box, snap_target): return adjusted_box -class NestedPool(list): +class NestedPool(): """ - Manages a sampling from a tree of indexes (represented as nested lists). + Manages a sampling from a tree of indexes (nodes are dictionaries, leafs are lists). - Helps with balancing samples over multiple criteria + Helps with balancing samples over multiple criteria. Example: >>> from geowatch.tasks.fusion.datamodules.data_utils import NestedPool @@ -342,22 +342,27 @@ class NestedPool(list): >>> # In this case region1 occurs more often than region2 and there is >>> # a rare category that only appears twice. >>> sample_grid = [ - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'rare'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region1', 'category': 'background'}, - >>> {'region': 'region2', 'category': 'background'}, - >>> {'region': 'region2', 'category': 'background'}, - >>> {'region': 'region2', 'category': 'rare'}, + >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "purple" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "purple" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region1', 'category': 'rare', 'color': "red" }, + >>> { 'region': 'region1', 'category': 'rare', 'color': "green" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "blue" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "purple" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region2', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region2', 'category': 'rare', 'color': "purple" }, + >>> { 'region': 'region2', 'category': 'rare', 'color': "blue" }, >>> ] >>> # >>> # First we can just create a flat uniform sampling grid - >>> # And inspect the imbalance that causes. + >>> # and inspect the imbalance that causes. 
>>> sample_idxs = list(range(len(sample_grid))) >>> self = NestedPool(sample_idxs) >>> print(f'self={self}') @@ -378,23 +383,72 @@ class NestedPool(list): >>> sampled = list(self._sample_many(100, sample_grid)) >>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1))) + >>> # + >>> # We can further subdivide by color, using custom weights. + >>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 } + >>> self.subdivide([g['color'] for g in sample_grid], weights=weights) + >>> print(f'self={self}') + >>> sampled = list(self._sample_many(100, sample_grid)) + >>> hist3 = ub.dict_hist([ + >>> (g['region'], g['category'], g['color']) for g in sampled + >>> ]) + >>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1))) + >>> hist3_color = ub.dict_hist([(g['color']) for g in sampled]) + >>> print('color weights = {}'.format(ub.urepr(weights, nl=1))) + >>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1))) Example: >>> from geowatch.tasks.fusion.datamodules.data_utils import * # NOQA >>> nested1 = NestedPool([[[1], [2, 3], [4, 5, 6], [7, 8, 9, 0]], [[11, 12, 13]]]) - >>> list(nested1.leafs()) - >>> print({nested1.sample() for i in range(100)}) >>> nested2 = NestedPool([[101], [102, 103], [104, 105, 106], [107, 8, 9, 0]]) >>> print({nested2.sample() for i in range(100)}) - >>> nested3 = NestedPool([nested1, nested2, [4, 59, 9, [], []]]) - >>> print({nested3.sample() for i in range(100)}) - >>> print(ub.urepr(ub.dict_hist(nested3.sample() for i in range(100)))) """ def __init__(self, pools, rng=None): - super().__init__(pools) self.rng = rng = kwarray.ensure_rng(rng) - self.pools = pools + self.pools = self._convert_to_weighted(self._validate(pools)) + + def _validate(self, _input): + # TODO: robustly validate the input to __init__ + if not isinstance(_input, list): + raise ValueError('NestedPool requires a list as input.') + if len(_input) == 0: + raise ValueError('NestedPool received an empty list as input.') + + def remove_empty_leafs(nested): + if not isinstance(nested, list): + return nested + return list(filter(lambda x: x != [], (map(remove_empty_leafs, nested)))) + return remove_empty_leafs(_input) + + def _compute_depth(self, x): + return isinstance(x, list) and max(map(self._compute_depth, x)) + 1 + + def _make_node(self, x): + return {"weights": None, "children": x} + + def _convert_to_weighted(self, nested): + """ + Convert from a tree (as a nested list of leaf values) to a representation + where nodes are dictionaries and children are lists. This allows specifying a + weight at every node to use when sampling. + + Note: A single level is still just a flat list internally (a leaf). 
+ """ + if not isinstance(nested, list): + return nested + + max_depth = self._compute_depth(nested) + if max_depth == 1: + return nested + + if max_depth == 2 and len(nested) >= 2: + return self._make_node(nested) + else: + collect = [] + for o in nested: + collect.append(self._convert_to_weighted(o)) + return self._make_node(collect) def _sample_many(self, num, items): for _ in range(num): @@ -402,7 +456,62 @@ class NestedPool(list): item = items[idx] yield item - def subdivide(self, items, key=None): + def sample(self): + chosen = self.pools + while ub.iterable(chosen): + if isinstance(chosen, dict): + # processing a node, sample using weights + weights = chosen["weights"] + children = chosen["children"] + num = len(children) + if weights is None: + idx = self.rng.randint(0, num) + else: + idx = self.rng.choice(num, 1, p=weights)[0] + chosen = children[idx] + elif isinstance(chosen, list): + # processing a leaf, sample uniformly + num = len(chosen) + idx = self.rng.randint(0, num) + chosen = chosen[idx] + return chosen + + def _subdivide_leaf(self, leaf, items, key=None, weights=None): + assert isinstance(leaf, list) + if len(leaf) == 1: + return leaf + + if key is not None: + groupids = list(map(key, ub.take(items, leaf))) + else: + groupids = list(ub.take(items, leaf)) + + groups = ub.group_items(leaf, groupids) + group_keys = groups.keys() + group_values = list(groups.values()) + + if len(group_values) > 1: + if weights is not None: + group_weights = np.asarray(list(ub.take(weights, group_keys))) + weights = group_weights / group_weights.sum() + return {"weights": weights, "children": group_values} + + def _subdivide_dict(self, nested, items, key=None, weights=None): + if not isinstance(nested, (list, dict)): + return nested + if isinstance(nested, dict): + if isinstance(nested["children"][0], list): + # children are leafs + nested["children"] = self._subdivide_dict(nested["children"], items, key=key, weights=weights) + return nested + else: + # children are nodes + nested["children"] = [self._subdivide_dict(x, items, key=key, weights=weights) for x in nested["children"]] + return nested + else: + return [self._subdivide_leaf(o, items, key=key, weights=weights) for o in nested] + + def subdivide(self, items, key=None, weights=None): """ Args: items (List): @@ -413,50 +522,43 @@ class NestedPool(list): key (None | Callable): if specified, for each ``items[i]`` found transform it into the group-id based on ``key(items[i])``. - """ - for leaf in self.leafs(): - if key is not None: - groupids = list(map(key, ub.take(items, leaf))) - else: - groupids = list(ub.take(items, leaf)) - new_subleafs = list(ub.group_items(leaf, groupids).values()) - if len(new_subleafs) > 1: - # Clear the current leaf and replace it with new subleafs - leaf[:] = new_subleafs - def leafs(self): + weights (None | Dict): + a dictionary of weights for possible categories in items. """ - Iterate over the deepest index lists in this pool. 
- """ - stack = [self] - while stack: - curr = stack.pop() - assert ub.iterable(curr) - if len(curr) == 0 or not ub.iterable(curr[0]): - # Found a leaf - yield curr - else: - for child in curr: - stack.append(child) + if isinstance(self.pools, list): + self.pools = self._subdivide_leaf(self.pools, + items, + key=key, + weights=weights) + else: + self.pools = self._subdivide_dict(self.pools, + items, + key=key, + weights=weights) + + def __len__(self): + if isinstance(self.pools, list): + return len(self.pools) + else: + def nested_len(nested): + return sum(nested_len(x) if isinstance(x, list) else 1 for x in nested) + pool_list = self._traverse(self.pools) + return nested_len(pool_list) + + def __str__(self): + if isinstance(self.pools, list): + return str(self.pools) + else: + return str(self._traverse(self.pools)) - def sample(self): - # Hack for empty lists - chosen = self - i = 0 - while ub.iterable(chosen): - chosen = self - i += 1 - while ub.iterable(chosen): - i += 1 - num = len(chosen) - # Fixme: not robust - if i > 100000: - raise Exception('Too many samples. Bad balance?') - if not num: - break - idx = self.rng.randint(0, num) - chosen = chosen[idx] - return chosen + def _traverse(self, nested): + if not isinstance(nested, (list, dict)): + return nested + if isinstance(nested, dict): + return self._traverse(nested["children"]) + else: + return [self._traverse(o) for o in nested] def samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0, -- GitLab From 21d66b340cee5eaee92834f964aa175e3f32f429 Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Mon, 26 Feb 2024 20:46:25 -0500 Subject: [PATCH 2/7] use networkx to manage the tree --- .../tasks/fusion/datamodules/data_utils.py | 271 +++++++----------- 1 file changed, 103 insertions(+), 168 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index f68ba54d6..d29ee4f41 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -6,6 +6,7 @@ import numpy as np import ubelt as ub import kwimage import kwarray +import networkx as nx def resolve_scale_request(request=None, data_gsd=None): @@ -331,16 +332,13 @@ def _boxes_snap_to_edges(given_box, snap_target): class NestedPool(): """ - Manages a sampling from a tree of indexes (nodes are dictionaries, leafs are lists). - - Helps with balancing samples over multiple criteria. + Manages a sampling from a tree of indexes. Helps with balancing + samples over multiple criteria. Example: >>> from geowatch.tasks.fusion.datamodules.data_utils import NestedPool - >>> # Lets say that you have a grid of sample locations with information - >>> # about them - say a source region and what category they contain. - >>> # In this case region1 occurs more often than region2 and there is - >>> # a rare category that only appears twice. + >>> # Given a grid of sample locations and attribute information + >>> # (e.g., region, category). 
>>> sample_grid = [ >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, >>> { 'region': 'region1', 'category': 'background', 'color': "purple" }, @@ -351,8 +349,8 @@ class NestedPool(): >>> { 'region': 'region1', 'category': 'background', 'color': "blue" }, >>> { 'region': 'region1', 'category': 'rare', 'color': "red" }, >>> { 'region': 'region1', 'category': 'rare', 'color': "green" }, - >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, - >>> { 'region': 'region2', 'category': 'background', 'color': "green" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "red" }, + >>> { 'region': 'region1', 'category': 'background', 'color': "green" }, >>> { 'region': 'region2', 'category': 'background', 'color': "blue" }, >>> { 'region': 'region2', 'category': 'background', 'color': "purple" }, >>> { 'region': 'region2', 'category': 'background', 'color': "red" }, @@ -363,202 +361,139 @@ class NestedPool(): >>> # >>> # First we can just create a flat uniform sampling grid >>> # and inspect the imbalance that causes. - >>> sample_idxs = list(range(len(sample_grid))) - >>> self = NestedPool(sample_idxs) + >>> self = NestedPool(sample_grid) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) + >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1))) >>> # >>> # We can subdivide the indexes based on region to improve balance. - >>> self.subdivide([g['region'] for g in sample_grid]) + >>> self.subdivide('region') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) + >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist1 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1))) >>> # >>> # We can further subdivide by category. - >>> self.subdivide([g['category'] for g in sample_grid]) + >>> self.subdivide('category') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) + >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1))) >>> # - >>> # We can further subdivide by color, using custom weights. + >>> # We can further subdivide by color, with custom weights. 
>>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 } - >>> self.subdivide([g['color'] for g in sample_grid], weights=weights) + >>> self.subdivide('color', weights=weights) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, sample_grid)) - >>> hist3 = ub.dict_hist([ + >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> hist2 = ub.dict_hist([ >>> (g['region'], g['category'], g['color']) for g in sampled >>> ]) - >>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1))) - >>> hist3_color = ub.dict_hist([(g['color']) for g in sampled]) + >>> print('hist3 = {}'.format(ub.urepr(hist2, nl=1))) + >>> hist2 = ub.dict_hist([(g['color']) for g in sampled]) >>> print('color weights = {}'.format(ub.urepr(weights, nl=1))) - >>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1))) - - Example: - >>> from geowatch.tasks.fusion.datamodules.data_utils import * # NOQA - >>> nested1 = NestedPool([[[1], [2, 3], [4, 5, 6], [7, 8, 9, 0]], [[11, 12, 13]]]) - >>> print({nested1.sample() for i in range(100)}) - >>> nested2 = NestedPool([[101], [102, 103], [104, 105, 106], [107, 8, 9, 0]]) - >>> print({nested2.sample() for i in range(100)}) + >>> print('hist3 (color) = {}'.format(ub.urepr(hist2, nl=1))) """ - def __init__(self, pools, rng=None): + def __init__(self, sample_grid, rng=None): self.rng = rng = kwarray.ensure_rng(rng) - self.pools = self._convert_to_weighted(self._validate(pools)) - - def _validate(self, _input): - # TODO: robustly validate the input to __init__ - if not isinstance(_input, list): - raise ValueError('NestedPool requires a list as input.') - if len(_input) == 0: - raise ValueError('NestedPool received an empty list as input.') + self.graph = self._create_graph(sample_grid) - def remove_empty_leafs(nested): - if not isinstance(nested, list): - return nested - return list(filter(lambda x: x != [], (map(remove_empty_leafs, nested)))) - return remove_empty_leafs(_input) + def _create_graph(self, sample_grid): + graph = nx.DiGraph() - def _compute_depth(self, x): - return isinstance(x, list) and max(map(self._compute_depth, x)) + 1 + # make a special root node + root_node = '__root__' + graph.add_node(root_node, weights=None) - def _make_node(self, x): - return {"weights": None, "children": x} - - def _convert_to_weighted(self, nested): - """ - Convert from a tree (as a nested list of leaf values) to a representation - where nodes are dictionaries and children are lists. This allows specifying a - weight at every node to use when sampling. - - Note: A single level is still just a flat list internally (a leaf). - """ - if not isinstance(nested, list): - return nested - - max_depth = self._compute_depth(nested) - if max_depth == 1: - return nested - - if max_depth == 2 and len(nested) >= 2: - return self._make_node(nested) + for index, item in enumerate(sample_grid): + label = f'{index:02d} ' + ub.urepr(item, nl=0, compact=1, nobr=1) + if isinstance(item, dict): + graph.add_node(index, label=label, **item) + else: + graph.add_node(index, label=label) + graph.add_edge(root_node, index) + return graph + + def _get_leafs(self): + """ Return sink nodes for the graph """ + return (n for n in self.graph.nodes if self.graph.out_degree[n] == 0) + + def _get_parent(self, n): + """ Get the parent of a node (assume a tree). 
None if it doesnt exist """ + preds = self.graph.pred[n] + if len(preds): + assert len(preds) == 1 + return next(iter(preds)) else: - collect = [] - for o in nested: - collect.append(self._convert_to_weighted(o)) - return self._make_node(collect) + return None + + def subdivide(self, key, weights=None): + remove_edges = [] + add_edges = [] + add_nodes = [] + + # Group all leaf nodes by their direct parents + parent_to_leafs = ub.group_items(self._get_leafs(), key=lambda n: self._get_parent(n)) + for parent, children in parent_to_leafs.items(): + # Group children by the new attribute + val_to_subgroup = ub.group_items(children, lambda n: self.graph.nodes[n][key]) + if len(val_to_subgroup) == 1: + # Dont need to do anything if no splits were made + ... + else: + # Otherwise, we have to subdivide the children + for value, subgroup in val_to_subgroup.items(): + # Use a dotted name to make unambiguous tree splits + new_parent = f'{parent}.{key}={value}' + # Mark edges to add / remove to implement the split + remove_edges.extend([(parent, n) for n in subgroup]) + add_edges.extend([(parent, new_parent) for n in subgroup]) + add_edges.extend([(new_parent, n) for n in subgroup]) + add_nodes.append(new_parent) + + # Add weights to the prior parent + if weights is not None: + weights_group = np.asarray(list(ub.take(weights, val_to_subgroup.keys()))) + weights_group = weights_group / weights_group.sum() + self.graph.nodes[parent]['weights'] = weights_group + else: + self.graph.nodes[parent]["weights"] = None + + # Modify the graph + self.graph.remove_edges_from(remove_edges) + self.graph.add_nodes_from(add_nodes, weights=None) + self.graph.add_edges_from(add_edges) - def _sample_many(self, num, items): + def _sample_many(self, num, return_attributes=False): for _ in range(num): idx = self.sample() - item = items[idx] - yield item + if return_attributes: + node = dict(self.graph.nodes[idx]) + node.pop("label") + yield node + else: + yield idx def sample(self): - chosen = self.pools - while ub.iterable(chosen): - if isinstance(chosen, dict): - # processing a node, sample using weights - weights = chosen["weights"] - children = chosen["children"] - num = len(children) - if weights is None: - idx = self.rng.randint(0, num) - else: - idx = self.rng.choice(num, 1, p=weights)[0] - chosen = children[idx] - elif isinstance(chosen, list): - # processing a leaf, sample uniformly - num = len(chosen) - idx = self.rng.randint(0, num) - chosen = chosen[idx] - return chosen - - def _subdivide_leaf(self, leaf, items, key=None, weights=None): - assert isinstance(leaf, list) - if len(leaf) == 1: - return leaf + current = '__root__' + while self.graph.out_degree(current) > 0: + children = list(self.graph.neighbors(current)) + num = len(children) - if key is not None: - groupids = list(map(key, ub.take(items, leaf))) - else: - groupids = list(ub.take(items, leaf)) - - groups = ub.group_items(leaf, groupids) - group_keys = groups.keys() - group_values = list(groups.values()) - - if len(group_values) > 1: - if weights is not None: - group_weights = np.asarray(list(ub.take(weights, group_keys))) - weights = group_weights / group_weights.sum() - return {"weights": weights, "children": group_values} - - def _subdivide_dict(self, nested, items, key=None, weights=None): - if not isinstance(nested, (list, dict)): - return nested - if isinstance(nested, dict): - if isinstance(nested["children"][0], list): - # children are leafs - nested["children"] = self._subdivide_dict(nested["children"], items, key=key, weights=weights) - 
return nested + weights = self.graph.nodes[current]['weights'] + if weights is None: + idx = self.rng.randint(0, num) else: - # children are nodes - nested["children"] = [self._subdivide_dict(x, items, key=key, weights=weights) for x in nested["children"]] - return nested - else: - return [self._subdivide_leaf(o, items, key=key, weights=weights) for o in nested] + idx = self.rng.choice(num, 1, p=weights)[0] - def subdivide(self, items, key=None, weights=None): - """ - Args: - items (List): - a list of items that the indexes index into. - If these are not the attributes to split nodes on, then - key must be specified: - - key (None | Callable): - if specified, for each ``items[i]`` found transform it into - the group-id based on ``key(items[i])``. - - weights (None | Dict): - a dictionary of weights for possible categories in items. - """ - if isinstance(self.pools, list): - self.pools = self._subdivide_leaf(self.pools, - items, - key=key, - weights=weights) - else: - self.pools = self._subdivide_dict(self.pools, - items, - key=key, - weights=weights) + current = children[idx] + return current def __len__(self): - if isinstance(self.pools, list): - return len(self.pools) - else: - def nested_len(nested): - return sum(nested_len(x) if isinstance(x, list) else 1 for x in nested) - pool_list = self._traverse(self.pools) - return nested_len(pool_list) + return len(list(self._get_leafs())) def __str__(self): - if isinstance(self.pools, list): - return str(self.pools) - else: - return str(self._traverse(self.pools)) - - def _traverse(self, nested): - if not isinstance(nested, (list, dict)): - return nested - if isinstance(nested, dict): - return self._traverse(nested["children"]) - else: - return [self._traverse(o) for o in nested] + return "\n".join(x for x in nx.generate_network_text(self.graph)) def samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0, -- GitLab From a3559ea957340bd0cbe85611cf81190e76cbbedd Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 13:25:31 -0500 Subject: [PATCH 3/7] add input validation so tests properly fail --- geowatch/tasks/fusion/datamodules/data_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index d29ee4f41..d556a3b54 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -396,7 +396,14 @@ class NestedPool(): """ def __init__(self, sample_grid, rng=None): self.rng = rng = kwarray.ensure_rng(rng) - self.graph = self._create_graph(sample_grid) + + # validate input + if isinstance(sample_grid, list) and sample_grid: + if isinstance(sample_grid[0], (dict, int, float)): + self.graph = self._create_graph(sample_grid) + return + raise ValueError("""NestedPool only supports input in the + form of a flat list or list of dicts.""") def _create_graph(self, sample_grid): graph = nx.DiGraph() -- GitLab From 00d92cce0407598a0c5a0d62e09ee384eca8ee46 Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 14:52:09 -0500 Subject: [PATCH 4/7] propagate changes from NestedPool to balanced sampling --- .../fusion/datamodules/kwcoco_dataset.py | 114 ++++++++---------- 1 file changed, 53 insertions(+), 61 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py index c90aad709..7b104d05d 100644 --- 
a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py +++ b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py @@ -136,7 +136,6 @@ from typing import NamedTuple from geowatch import heuristics from geowatch.utils import kwcoco_extensions from geowatch.utils import util_bands -from geowatch.utils import util_iter from geowatch.utils import util_kwarray from geowatch.utils import util_kwimage from geowatch.tasks.fusion import utils @@ -2388,11 +2387,8 @@ class BalanceMixin: Helpers to build the sample grid and balance it """ - def _setup_balance_dataframe(self, new_sample_grid): - target_vidids = [v['video_id'] for v in new_sample_grid['targets']] - - # extract video names - unique_vidids, _idx_to_unique_idx = np.unique(target_vidids, return_inverse=True) + def _get_video_names(self, vidids): + unique_vidids, _idx_to_unique_idx = np.unique(vidids, return_inverse=True) coco_dset = self.sampler.dset try: unique_vidnames = self.sampler.dset.videos(unique_vidids).lookup('name') @@ -2406,12 +2402,9 @@ class BalanceMixin: vidname = video_id unique_vidnames.append(vidname) vidnames = list(ub.take(unique_vidnames, _idx_to_unique_idx)) + return vidnames - # associate targets with positive or negative - target_posbit = kwarray.boolmask( - new_sample_grid['positives_indexes'], - len(new_sample_grid['targets'])) - + def _get_region_names(self, vidnames): # create mapping from video name to region name from kwutil import util_pattern pat = util_pattern.Pattern.coerce(r'\w+_[A-Z]\d+_.*', 'regex') @@ -2421,13 +2414,43 @@ class BalanceMixin: self.vidname_to_region_name[vidname] = "_".join(vidname.split('_')[:2]) else: self.vidname_to_region_name[vidname] = vidname + return list(ub.take(self.vidname_to_region_name, vidnames)) + + def _get_observed_annotations(self, targets): + gid_to_category = ub.AutoDict() + for gid in self.sampler.dset.annots().gids: + cats = self.sampler.dset.annots(image_id=gid).category_names + gid_to_category[gid] = cats + + observed_annos = ub.AutoDict() + for idx, target in enumerate(targets): + observed_cats = gid_to_category[target["main_gid"]] + unique_cats = set(observed_cats) + observed_annos[idx] = unique_cats + return observed_annos + + def _setup_attribute_dataframe(self, new_sample_grid): + """ + Build a dataframe of attributes (for each sample) that can be used for balancing. + """ + video_ids = [v['video_id'] for v in new_sample_grid['targets']] + video_names = self._get_video_names(video_ids) + region_names = self._get_region_names(video_names) + observed_annos = self._get_observed_annotations(new_sample_grid['targets']) + observed_phases = ub.util_dict.map_vals(lambda x: set(heuristics.PHASES).intersection(x), observed_annos) + + # associate targets with positive or negative + contains_positive = ['positive' in v for (k, v) in observed_annos.items()] + contains_phase = [any(v) for (k, v) in observed_phases.items()] - # build a dataframe with video attributes + # build a dataframe with target attributes df = pd.DataFrame({ - 'vidid': target_vidids, - 'vidname': vidnames, - 'is_positive': target_posbit, - 'region': list(ub.take(self.vidname_to_region_name, vidnames)) + 'video_id': video_ids, + 'video_name': video_names, + 'region': region_names, + 'contains_positive': contains_positive, + 'contains_phase': contains_phase, + 'phases': observed_phases.values(), }).reset_index(drop=False) return df @@ -2436,55 +2459,24 @@ class BalanceMixin: Build data structure used for balanced sampling. 
Helper for __init__ which constructs a NestedPool to balance sampling - acrgeowatch/tasks/fusion/datamodules/kwcoco_dataset.pyoss input domains. - - TODO: - HELP WANTED: We would like to configure the distribution in some - easy to specify way. We should be domain aware, or rather accept - some encoding of the domain. We want to oversample underrepresented - or important batch items and undersample overrepresented or - unimportant easy batch items. The "batch item" part is what makes - this hard because we need the notation of goodness, easiness, etc - at the batch level, which can contain multiple annotations. + across input domains. """ - print('Balancing over regions') - df_videos = self._setup_balance_dataframe(new_sample_grid) - - # balance positive / negatives per region - neg_to_pos_ratio = self.config['neg_to_pos_ratio'] - region_to_pool = {} - for region in df_videos['region'].unique(): - df_pos_frames = df_videos.query(f"region == '{region}' and is_positive == True") - df_neg_frames = df_videos.query(f"region == '{region}' and is_positive == False") - pos_frames_idxs = df_pos_frames['index'] - neg_frames_idxs = df_neg_frames['index'] - - n_pos = len(df_pos_frames) - n_neg = len(df_neg_frames) - max_neg = min(int(max(1, (neg_to_pos_ratio * n_pos))), n_neg) - - neg_region_pool_ = list(util_iter.chunks(neg_frames_idxs, nchunks=max_neg)) - pos_region_pool_ = list(util_iter.chunks(pos_frames_idxs, nchunks=n_pos)) - region_pool = pos_region_pool_ + neg_region_pool_ - region_to_pool[region] = [p for p in region_pool if p] - - # compute maximum to take per region - freqs = list(map(len, region_to_pool.values())) - if len(freqs) == 0: - max_per_region = 100 - warnings.warn('Warning: no region pool') - else: - max_per_region = int(np.median(freqs)) + print('Balancing over attributes') + df_sample_attributes = self._setup_attribute_dataframe(new_sample_grid) + + # Initialize an instance of NestedPool + self.nested_pool = data_utils.NestedPool(df_sample_attributes.to_dict('records')) + + # Compute weights for subdivide + npr = self.config['neg_to_pos_ratio'] + npr_dist = np.asarray([1, npr]) / (1 + npr) + weights_pos = dict(zip([True, False], npr_dist)) - # balance across regions - all_chunks = [] - for region, region_pool in region_to_pool.items(): - rechunked_region_pool = list(util_iter.chunks(region_pool, nchunks=max_per_region)) - all_chunks.extend(rechunked_region_pool) + self.nested_pool.subdivide('region') + self.nested_pool.subdivide('contains_positive', weights=weights_pos) + self.nested_pool.subdivide('contains_phase') - # initialize nested pool - self.nested_pool = data_utils.NestedPool(all_chunks) if self.config['reseed_fit_random_generators']: self.reseed() -- GitLab From 0db3f7d8502d102fe2bf1b9e34369fa3fd5d3eaf Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 16:39:08 -0500 Subject: [PATCH 5/7] minor improvements --- .../tasks/fusion/datamodules/data_utils.py | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index d556a3b54..c4551ef36 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -330,7 +330,7 @@ def _boxes_snap_to_edges(given_box, snap_target): return adjusted_box -class NestedPool(): +class NestedPool(ub.NiceRepr): """ Manages a sampling from a tree of indexes. Helps with balancing samples over multiple criteria. 
@@ -395,12 +395,14 @@ class NestedPool(): >>> print('hist3 (color) = {}'.format(ub.urepr(hist2, nl=1))) """ def __init__(self, sample_grid, rng=None): + super().__init__() self.rng = rng = kwarray.ensure_rng(rng) # validate input if isinstance(sample_grid, list) and sample_grid: if isinstance(sample_grid[0], (dict, int, float)): self.graph = self._create_graph(sample_grid) + self._leaf_nodes = [n for n in self.graph.nodes if self.graph.out_degree[n] == 0] return raise ValueError("""NestedPool only supports input in the form of a flat list or list of dicts.""") @@ -421,10 +423,6 @@ class NestedPool(): graph.add_edge(root_node, index) return graph - def _get_leafs(self): - """ Return sink nodes for the graph """ - return (n for n in self.graph.nodes if self.graph.out_degree[n] == 0) - def _get_parent(self, n): """ Get the parent of a node (assume a tree). None if it doesnt exist """ preds = self.graph.pred[n] @@ -440,7 +438,7 @@ class NestedPool(): add_nodes = [] # Group all leaf nodes by their direct parents - parent_to_leafs = ub.group_items(self._get_leafs(), key=lambda n: self._get_parent(n)) + parent_to_leafs = ub.group_items(self._leaf_nodes, key=lambda n: self._get_parent(n)) for parent, children in parent_to_leafs.items(): # Group children by the new attribute val_to_subgroup = ub.group_items(children, lambda n: self.graph.nodes[n][key]) @@ -471,20 +469,15 @@ class NestedPool(): self.graph.add_nodes_from(add_nodes, weights=None) self.graph.add_edges_from(add_edges) - def _sample_many(self, num, return_attributes=False): + def _sample_many(self, num): for _ in range(num): idx = self.sample() - if return_attributes: - node = dict(self.graph.nodes[idx]) - node.pop("label") - yield node - else: - yield idx + yield idx def sample(self): current = '__root__' while self.graph.out_degree(current) > 0: - children = list(self.graph.neighbors(current)) + children = list(self.graph.successors(current)) num = len(children) weights = self.graph.nodes[current]['weights'] @@ -497,10 +490,14 @@ class NestedPool(): return current def __len__(self): - return len(list(self._get_leafs())) - - def __str__(self): - return "\n".join(x for x in nx.generate_network_text(self.graph)) + return len(list(self._leaf_nodes)) + + def __nice__(self): + n_nodes = self.graph.number_of_nodes() + n_edges = self.graph.number_of_edges() + n_leafs = self.__len__() + n_depth = len(nx.algorithms.dag.dag_longest_path(self.graph)) + return f'nodes={n_nodes}, edges={n_edges}, leafs={n_leafs}, depth={n_depth}' def samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0, -- GitLab From 9e4112323686ef5dd9fd77d9f4ac14bccbee9fbf Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 16:51:43 -0500 Subject: [PATCH 6/7] rename NestedPool to BalancedSampleTree --- geowatch/tasks/fusion/datamodules/data_utils.py | 8 ++++---- geowatch/tasks/fusion/datamodules/kwcoco_dataset.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index c4551ef36..aeabb1e89 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -330,13 +330,13 @@ def _boxes_snap_to_edges(given_box, snap_target): return adjusted_box -class NestedPool(ub.NiceRepr): +class BalancedSampleTree(ub.NiceRepr): """ Manages a sampling from a tree of indexes. Helps with balancing samples over multiple criteria. 
Example: - >>> from geowatch.tasks.fusion.datamodules.data_utils import NestedPool + >>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleTree >>> # Given a grid of sample locations and attribute information >>> # (e.g., region, category). >>> sample_grid = [ @@ -361,7 +361,7 @@ class NestedPool(ub.NiceRepr): >>> # >>> # First we can just create a flat uniform sampling grid >>> # and inspect the imbalance that causes. - >>> self = NestedPool(sample_grid) + >>> self = BalancedSampleTree(sample_grid) >>> print(f'self={self}') >>> sampled = list(self._sample_many(100, return_attributes=True)) >>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) @@ -404,7 +404,7 @@ class NestedPool(ub.NiceRepr): self.graph = self._create_graph(sample_grid) self._leaf_nodes = [n for n in self.graph.nodes if self.graph.out_degree[n] == 0] return - raise ValueError("""NestedPool only supports input in the + raise ValueError("""BalancedSampleTree only supports input in the form of a flat list or list of dicts.""") def _create_graph(self, sample_grid): diff --git a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py index 7b104d05d..4e1b44e74 100644 --- a/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py +++ b/geowatch/tasks/fusion/datamodules/kwcoco_dataset.py @@ -2458,15 +2458,15 @@ class BalanceMixin: """ Build data structure used for balanced sampling. - Helper for __init__ which constructs a NestedPool to balance sampling + Helper for __init__ which constructs a BalancedSampleTree to balance sampling across input domains. """ print('Balancing over attributes') df_sample_attributes = self._setup_attribute_dataframe(new_sample_grid) - # Initialize an instance of NestedPool - self.nested_pool = data_utils.NestedPool(df_sample_attributes.to_dict('records')) + # Initialize an instance of BalancedSampleTree + self.nested_pool = data_utils.BalancedSampleTree(df_sample_attributes.to_dict('records')) # Compute weights for subdivide npr = self.config['neg_to_pos_ratio'] -- GitLab From 671e55fcf99b7603e2f63752033a961d3ae678dc Mon Sep 17 00:00:00 2001 From: Scott Workman <sworkman@dzynetech.com> Date: Tue, 27 Feb 2024 21:44:25 -0500 Subject: [PATCH 7/7] handle edge case, update test --- .../tasks/fusion/datamodules/data_utils.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/geowatch/tasks/fusion/datamodules/data_utils.py b/geowatch/tasks/fusion/datamodules/data_utils.py index aeabb1e89..c9dff6055 100644 --- a/geowatch/tasks/fusion/datamodules/data_utils.py +++ b/geowatch/tasks/fusion/datamodules/data_utils.py @@ -363,21 +363,21 @@ class BalancedSampleTree(ub.NiceRepr): >>> # and inspect the imbalance that causes. >>> self = BalancedSampleTree(sample_grid) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) >>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1))) >>> # >>> # We can subdivide the indexes based on region to improve balance. >>> self.subdivide('region') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) >>> hist1 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1))) >>> # >>> # We can further subdivide by category. 
>>> self.subdivide('category') >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) >>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled]) >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1))) >>> # @@ -385,14 +385,14 @@ class BalancedSampleTree(ub.NiceRepr): >>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 } >>> self.subdivide('color', weights=weights) >>> print(f'self={self}') - >>> sampled = list(self._sample_many(100, return_attributes=True)) - >>> hist2 = ub.dict_hist([ + >>> sampled = list(ub.take(sample_grid, self._sample_many(100))) + >>> hist3 = ub.dict_hist([ >>> (g['region'], g['category'], g['color']) for g in sampled >>> ]) - >>> print('hist3 = {}'.format(ub.urepr(hist2, nl=1))) - >>> hist2 = ub.dict_hist([(g['color']) for g in sampled]) + >>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1))) + >>> hist3_color = ub.dict_hist([(g['color']) for g in sampled]) >>> print('color weights = {}'.format(ub.urepr(weights, nl=1))) - >>> print('hist3 (color) = {}'.format(ub.urepr(hist2, nl=1))) + >>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1))) """ def __init__(self, sample_grid, rng=None): super().__init__() @@ -459,7 +459,10 @@ class BalancedSampleTree(ub.NiceRepr): # Add weights to the prior parent if weights is not None: weights_group = np.asarray(list(ub.take(weights, val_to_subgroup.keys()))) - weights_group = weights_group / weights_group.sum() + denom = weights_group.sum() + if denom == 0: + raise NotImplementedError('Zero weighted branches are not handled yet.') + weights_group = weights_group / denom self.graph.nodes[parent]['weights'] = weights_group else: self.graph.nodes[parent]["weights"] = None -- GitLab
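A minimal usage sketch of the BalancedSampleTree that this patch series ends up with, adapted from the doctest added in PATCH 7/7. It assumes a geowatch checkout with all seven patches applied (plus ubelt, networkx, kwarray, and numpy installed); the sample_grid records and the color weights below are illustrative stand-ins, not real geowatch sampling targets, and _sample_many is the underscore-prefixed helper the doctests themselves use.

    import ubelt as ub
    from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleTree

    # Illustrative grid of sample attributes (one dict per candidate sample).
    sample_grid = [
        {'region': 'region1', 'category': 'background', 'color': 'blue'},
        {'region': 'region1', 'category': 'background', 'color': 'green'},
        {'region': 'region1', 'category': 'rare',       'color': 'red'},
        {'region': 'region2', 'category': 'background', 'color': 'green'},
        {'region': 'region2', 'category': 'background', 'color': 'purple'},
        {'region': 'region2', 'category': 'rare',       'color': 'blue'},
    ]

    # Build a flat (uniform) pool over the sample indexes; rng seeds kwarray.ensure_rng.
    tree = BalancedSampleTree(sample_grid, rng=0)

    # Each subdivide call adds one level to the tree, so sampling first
    # balances regions, then categories within each region.
    tree.subdivide('region')
    tree.subdivide('category')

    # Optional per-value weights for the new level; they are normalized per node.
    tree.subdivide('color', weights={'red': .25, 'blue': .25, 'green': .4, 'purple': .1})

    # sample() walks from the root to a leaf, which is an index into sample_grid.
    sampled = list(ub.take(sample_grid, tree._sample_many(100)))
    hist = ub.dict_hist((g['region'], g['category']) for g in sampled)
    print(ub.urepr(hist, nl=1))

This mirrors how PATCH 4/7 wires the class into BalanceMixin: build the attribute records (region, contains_positive, contains_phase), construct one BalancedSampleTree over them, then call subdivide once per attribute, passing weights only where a non-uniform split (such as the neg_to_pos_ratio) is wanted.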