diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b377b73d6cd631a935ad9229ce80a2ef1e189e26..51383c714f7304a0a0ebbd24ae442c2a31f61cf8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,6 +26,7 @@ stages: # Tags define which runners will accept which jobs - docker - linux + - build variables: # Change pip's cache directory to be inside the project directory since we can @@ -75,6 +76,8 @@ stages: - python -V # Print out python version for debugging - pip install --progress-bar off -r requirements.txt - pip install . + # FIXME: we should start from a docker image that already has LibGL setup + - apt update && apt install libgl1-mesa-glx -y && rm -rf /var/lib/apt/lists/* script: - ./run_tests.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ac474641325d5f2f032dfbd708294442cc99116d..5ad6538cf54b8f33dcc609607f0aa0cf78ab547c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,30 @@ This changelog follows the specifications detailed in: [Keep a Changelog](https: This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release. -## Version 0.5.7 - Unreleased +## Version 0.5.9 - Unreleased + +### Changed + +* `_dump_monitor_tensorboard` now additionally writes a bash script to quickly + let the user re-visualize results in the case of mpl backend failure. + +* `load_partial_state` now has an algorithm to better match model keys when the + only difference is in key prefixes. + - Adds keyword arg `association`, which defaults to `prefix-hack`; the old default was `module-hack`, and `embedding` is more theoretically correct but too slow. + + +### Fixed +* `Optimizer.coerce` now works correctly with any `torch.optim` or `torch_optimizer` optimizer. + +### Added + +* `BatchContainer.pack` for easier use of non-container-aware models. +* `colored` option to `FitHarnPreferences`, which can be set to False to disable ANSI coloring. + + +## Version 0.5.8 - Released + +## Version 0.5.7 - Released ### Changed * `harn.deploy_fpath` is now populated when the model is deployed. @@ -489,5 +512,3 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Added * Early and undocumented commits - -## Version 0.5.8 - Unreleased diff --git a/dev/manage_snapshots.py b/dev/manage_snapshots.py index f878efdc6e6b3557fd0e580830ea7c28496e4f65..ddc08d99aba16e82d122502a2bfccf3d7cc27032 100755 --- a/dev/manage_snapshots.py +++ b/dev/manage_snapshots.py @@ -311,6 +311,20 @@ def _devcheck_manage_monitor(workdir, dry=True): _choose_action(file_infos) all_files.extend(file_infos) + dpath = join(session.dpath, 'monitor', 'train') + fpaths = list(glob.glob(join(dpath, '*.jpg'))) + file_infos = [{'size': os.stat(p).st_size, 'fpath': p} + for p in fpaths] + _choose_action(file_infos) + all_files.extend(file_infos) + + dpath = join(session.dpath, 'monitor', 'vali') + fpaths = list(glob.glob(join(dpath, '*.jpg'))) + file_infos = [{'size': os.stat(p).st_size, 'fpath': p} + for p in fpaths] + _choose_action(file_infos) + all_files.extend(file_infos) + grouped_actions = ub.group_items(all_files, lambda x: x['action']) for key, group in grouped_actions.items(): @@ -450,8 +464,9 @@ def main(): if mode == 'runs': _devcheck_remove_dead_runs(workdir=ns['workdir'], dry=ns['dry']) elif mode == 'snapshots': - print("A") _devcheck_manage_snapshots(**ns) + elif mode == 'monitor': + _devcheck_manage_monitor(workdir=ns['workdir'], dry=ns['dry']) else: raise KeyError(mode) @@ -463,9 +478,16 @@ if __name__ == '__main__': """ CommandLine: find . 
-iname "explit_checkpoints" -d - python ~/code/netharn/dev/manage_snapshots.py --mode=snapshots --workdir=~/work/voc_yolo2/ + python ~/code/netharn/dev/manage_snapshots.py --mode=snapshots --workdir=~/work/voc_yolo2/ --recent 2 --factor 40 python ~/code/netharn/dev/manage_snapshots.py --mode=runs --workdir=~/work/voc_yolo2/ + python ~/code/netharn/dev/manage_snapshots.py --mode=monitor --workdir=~/work/voc_yolo2/ + + Notes: + # Remove random files + # https://superuser.com/questions/1186350/delete-all-but-1000-random-files-in-a-directory + find . -type f -print0 | sort -zR | tail -zn +501 | xargs -0 rm + + - python ~/code/netharn/dev/manage_snapshots.py --mode=snapshots --workdir=~/work/mc_harn3/ --recent 2 --factor 40 """ main() diff --git a/netharn/analytic/analytic_for.py b/netharn/analytic/analytic_for.py index 9c413eb360cde22e31b575e4475047a8cdbe46d8..583476a754c12befe634e822c6ab0fb7bdfe0199 100644 --- a/netharn/analytic/analytic_for.py +++ b/netharn/analytic/analytic_for.py @@ -1,6 +1,128 @@ """ Code for commonalities between "X for" objects that compute analytic properties of networks like OutputShapeFor and ReceptiveFieldFor + + +The purpose of analytic modules is to make it easy to introspect both the final +and intermediate tensor shapes and receptive fields. As long as the relevant +``output_shape_for``, ``receptive_field_for``, or ``_analytic_forward`` methods +are defined, the computation will be fully symbolic. SeeAlso +:class:`netharn.layers.AnalyticModule`. + + +Example: + >>> import torch + >>> import netharn as nh + >>> # Inheriting from nh.layers.AnalyticModule lets us define _analytic_forward + >>> class MyNetwork(nh.layers.AnalyticModule): + >>> def __init__(self, classes): + >>> super().__init__() + >>> self.classes = classes + >>> # Note we are just using regular torch layers here + >>> # No special tricks required as long as the computation for + >>> # receptive field / output shape is registered. + >>> self.backbone = torch.nn.Sequential(*[ + >>> torch.nn.Conv2d(3, 32, kernel_size=3), + >>> torch.nn.BatchNorm2d(32), + >>> torch.nn.MaxPool2d(2, stride=2), + >>> torch.nn.ReLU(), + >>> torch.nn.Conv2d(32, 256, kernel_size=3, stride=2), + >>> torch.nn.BatchNorm2d(256), + >>> ]) + >>> self.clf_head = torch.nn.Conv2d(256, len(self.classes), kernel_size=1) + >>> def _analytic_forward(self, inputs, _OutputFor, _Output, _Hidden, + >>> **kwargs): + >>> # Defining the analytic forward function and using the _OutputFor + >>> # wrappers instead of calling each module directly will + >>> # automatically define the symbolic computation for + >>> # output_shape_for, receptive_field_for, and the real + >>> # computation for forward. Using Hidden will track any + >>> # intermediate states. + >>> x = inputs + >>> hidden = _Hidden() + >>> x = hidden['backbone'] = _OutputFor(self.backbone)(x) + >>> x = hidden['clf_head'] = _OutputFor(self.clf_head)(x) + >>> outputs = { + >>> 'class_energy': x, + >>> } + >>> outputs = _Output.coerce(outputs, hidden) + >>> return outputs + >>> # We can create an instance of our network + >>> self = MyNetwork(['a', 'b']) + >>> # Asking about the output shape for any input shape is computed + >>> # without directly invoking any tensor operations. 
+ >>> output_shape = self.output_shape_for((None, 3, 32, 32)) + >>> print('output_shape = {!r}'.format(output_shape)) + >>> print(ub.repr2(output_shape.hidden, nl=-1)) + output_shape = OutputShapeDict([('class_energy', (None, 2, 7, 7))]) + { + 'backbone': { + '0': (None, 32, 30, 30), + '1': (None, 32, 30, 30), + '2': (None, 32, 15, 15), + '3': (None, 32, 15, 15), + '4': (None, 256, 7, 7), + '5': (None, 256, 7, 7) + }, + 'clf_head': (None, 2, 7, 7) + } + >>> # In most cases the receptive field does not need to know about the + >>> # input shape (adaptive layers are the exception here) + >>> rf = self.receptive_field_for() + >>> print('rf = {}'.format(ub.repr2(rf, nl=2))) + >>> print(ub.repr2(rf.hidden, nl=3)) + rf = { + 'class_energy': { + 'crop': np.array([3.5, 3.5], dtype=np.float64), + 'shape': np.array([8., 8.], dtype=np.float64), + 'stride': np.array([4., 4.], dtype=np.float64), + }, + } + { + 'backbone': { + '0': { + 'crop': np.array([1., 1.], dtype=np.float64), + 'shape': np.array([3., 3.], dtype=np.float64), + 'stride': np.array([1., 1.], dtype=np.float64), + }, + '1': { + 'crop': np.array([1., 1.], dtype=np.float64), + 'shape': np.array([3., 3.], dtype=np.float64), + 'stride': np.array([1., 1.], dtype=np.float64), + }, + '2': { + 'crop': np.array([1.5, 1.5], dtype=np.float64), + 'shape': np.array([4., 4.], dtype=np.float64), + 'stride': np.array([2., 2.], dtype=np.float64), + }, + '3': { + 'crop': np.array([1.5, 1.5], dtype=np.float64), + 'shape': np.array([4., 4.], dtype=np.float64), + 'stride': np.array([2., 2.], dtype=np.float64), + }, + '4': { + 'crop': np.array([3.5, 3.5], dtype=np.float64), + 'shape': np.array([8., 8.], dtype=np.float64), + 'stride': np.array([4., 4.], dtype=np.float64), + }, + '5': { + 'crop': np.array([3.5, 3.5], dtype=np.float64), + 'shape': np.array([8., 8.], dtype=np.float64), + 'stride': np.array([4., 4.], dtype=np.float64), + }, + }, + 'clf_head': { + 'crop': np.array([3.5, 3.5], dtype=np.float64), + 'shape': np.array([8., 8.], dtype=np.float64), + 'stride': np.array([4., 4.], dtype=np.float64), + }, + } + >>> # analytic forward ensures that your forward definition is consistent + >>> # with output_shape_for and analytic_for + >>> inputs = torch.rand(1, 3, 32, 32) + >>> outputs = self.forward(inputs) + >>> print('class_energy = {}'.format(outputs['class_energy'].shape)) + class_energy = torch.Size([1, 2, 7, 7]) """ import ubelt as ub from collections import OrderedDict diff --git a/netharn/analytic/output_shape_for.py b/netharn/analytic/output_shape_for.py index 0f4123f121282aece504fec7c40842a74c0ecbfa..a7cababf6b6b90af1ae54e047db55a9817e338cd 100644 --- a/netharn/analytic/output_shape_for.py +++ b/netharn/analytic/output_shape_for.py @@ -952,7 +952,8 @@ class OutputShapeFor(analytic_for.OutputFor): >>> from netharn.analytic.output_shape_for import * >>> module = torchvision.models.resnet50() >>> input_shape = (1, 3, 224, 224) - >>> field = OutputShapeFor(module)(input_shape=input_shape) + >>> shape = OutputShapeFor(module)(input_shape=input_shape) + >>> print(ub.repr2(shape.hidden, nl=-1)) """ shape = input_shape diff --git a/netharn/api.py b/netharn/api.py index ff161c26b45a59c694e4ab58152e890762f035e5..cb75bc607f67f53633a75e2fa85807dcb70b4185 100644 --- a/netharn/api.py +++ b/netharn/api.py @@ -102,7 +102,7 @@ class Initializer(object): >>> print(ub.repr2(nh.Initializer.coerce(config))) ( , - {'fpath': '/fit/nice/untitled', 'leftover': None, 'mangle': True}, + {... 
'fpath': '/fit/nice/untitled', 'leftover': None, 'mangle': True}, ) >>> print(ub.repr2(nh.Initializer.coerce({'init': 'kaiming_normal'}))) ( @@ -151,6 +151,7 @@ class Initializer(object): 'fpath': ub.expandpath(config['pretrained_fpath']), 'leftover': kw.get('leftover', None), 'mangle': kw.get('mangle', True), + 'association': kw.get('association', None), }) elif config['init'] == 'cls': # Indicate that the model will initialize itself @@ -186,6 +187,9 @@ class Optimizer(object): https://datascience.stackexchange.com/questions/26792/difference-between-rmsprop-with-momentum-and-adam-optimizers https://github.com/jettify/pytorch-optimizer + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/api.py Optimizer.coerce + Example: >>> config = {'optimizer': 'sgd'} >>> optim_ = Optimizer.coerce(config) @@ -198,6 +202,14 @@ class Optimizer(object): >>> config = {'optimizer': 'Yogi'} >>> optim_ = Optimizer.coerce(config) >>> print('optim_ = {!r}'.format(optim_)) + + >>> from netharn.api import * # NOQA + >>> Optimizer.coerce({'optimizer': 'ASGD'}) + + TODO: + - [ ] https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/ + + """ import netharn as nh _update_defaults(config, kw) @@ -242,27 +254,43 @@ class Optimizer(object): 'alpha': 0.9, }) else: + from netharn.util import util_inspect try: import torch_optimizer except Exception: torch_optimizer = None - raise KeyError(key) - else: - - known = ['AccSGD', 'AdaBound', 'AdaMod', 'DiffGrad', 'Lamb', - 'Lookahead', 'NovoGrad', 'RAdam', 'SGDW', 'Yogi'] - - from netharn.util import util_inspect - if 0: - for key in known: - cls = getattr(torch_optimizer, key, None) - print('cls = {!r}'.format(cls)) - defaultkw = util_inspect.default_kwargs(cls) - print('defaultkw = {!r}'.format(defaultkw)) - - _lut = {k.lower(): k for k in known} - key = _lut[key] + _lut = {} + + if torch_optimizer is not None: + # known = ['AccSGD', 'AdaBound', 'AdaMod', 'DiffGrad', 'Lamb', + # 'Lookahead', 'NovoGrad', 'RAdam', 'SGDW', 'Yogi'] + # if 0: + # for key in known: + # cls = getattr(torch_optimizer, key, None) + # print('cls = {!r}'.format(cls)) + # defaultkw = util_inspect.default_kwargs(cls) + # print('defaultkw = {!r}'.format(defaultkw)) + # _lut.update({k.lower(): k for k in known}) + _lut.update({ + k: c.__name__ + for k, c in torch_optimizer._NAME_OPTIM_MAP.items()}) + + _lut.update({ + k.lower(): k for k in dir(torch.optim) + if not k.startswith('_')}) + + key = _lut[key] + + cls = getattr(torch.optim, key, None) + if cls is not None: + defaultkw = util_inspect.default_kwargs(cls) + kw = defaultkw.copy() + kw.update() + optim_ = (cls, kw) + else: + if torch_optimizer is None: + raise KeyError(key) cls = getattr(torch_optimizer, key, None) if cls is not None: defaultkw = util_inspect.default_kwargs(cls) diff --git a/netharn/criterions/triplet.py b/netharn/criterions/triplet.py index 81f5cfc94ceeb43734aa1a3fc4325283d762d019..58b1c3a8982c73a6b8b85ce44d9ff5b5bb49ae2f 100644 --- a/netharn/criterions/triplet.py +++ b/netharn/criterions/triplet.py @@ -23,6 +23,12 @@ def all_pairwise_distances(x, y=None, squared=False, approx=False): References: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065 + SeeAlso: + torch.nn.functional.pairwise_distance + torch.nn.functional.pdist + torch.norm(input[:, None] - input, dim=2, p=p) + + Example: >>> from netharn.criterions.triplet import * >>> N, d = 5, 3 diff --git a/netharn/data/data_containers.py b/netharn/data/data_containers.py index 
c8abfdc3fc273e00b7db7342f07dc55cb28f483e..bda1696f8113f5362a3d5b87f5c72f4034044176 100644 --- a/netharn/data/data_containers.py +++ b/netharn/data/data_containers.py @@ -56,6 +56,10 @@ class BatchContainer(ub.NiceRepr): Attributes: data (List): Unlike ItemContainer, data is always a list where len(data) is the number of devices this batch will run on. + Each item in the list may be either a pre-batched Tensor (in the + case where each item in the batch has the same shape) or a list + of individual item Tensors (in the case where different batch items + may have different shapes). """ def __init__(self, data, stack=False, padding_value=-1, cpu_only=False, pad_dims=2): @@ -68,9 +72,11 @@ class BatchContainer(ub.NiceRepr): } def __nice__(self): - shape_repr = ub.repr2(nestshape(self.data), nl=-2) - # return 'nestshape(data)={}, **{}'.format(shape_repr, ub.repr2(self.meta, nl=0)) - return 'nestshape(data)={}'.format(shape_repr) + try: + shape_repr = ub.repr2(nestshape(self.data), nl=-2) + return 'nestshape(data)={}'.format(shape_repr) + except Exception: + return super().__repr__() def __getitem__(self, index): cls = self.__class__ @@ -98,10 +104,10 @@ class BatchContainer(ub.NiceRepr): Concatenate data in multiple BatchContainers Example: - d1 = BatchContainer([torch.rand(3, 3, 1, 1), torch.rand(2, 3, 1, 1)]) - d2 = BatchContainer([torch.rand(3, 1, 1, 1), torch.rand(2, 1, 1, 1)]) - items = [d1, d2] - self = BatchContainer.cat(items, dim=1) + >>> d1 = BatchContainer([torch.rand(3, 3, 1, 1), torch.rand(2, 3, 1, 1)]) + >>> d2 = BatchContainer([torch.rand(3, 1, 1, 1), torch.rand(2, 1, 1, 1)]) + >>> items = [d1, d2] + >>> self = BatchContainer.cat(items, dim=1) """ newdata = [] num_devices = len(items[0].data) @@ -112,6 +118,48 @@ class BatchContainer(ub.NiceRepr): self = cls(newdata, **items[0].meta) return self + @classmethod + def demo(cls, key='img', n=5, num_devices=1): + inbatch = [ItemContainer.demo(key) for _ in range(n)] + self = ItemContainer._collate(inbatch, num_devices=num_devices) + return self + + def pack(self): + """ + Pack all of the data in this container into a single tensor. + + Returns: + Tensor: packed data, padded with ``self.padding_value`` if + ``self.stack`` is False. 
+ + Example: + >>> self = BatchContainer.demo('img') + >>> print(self.pack()) + >>> self = BatchContainer.demo('box') + >>> print(self.pack()) + >>> self = BatchContainer.demo('labels') + >>> print(self.pack()) + """ + if self.stack: + # Should be a straightforward concatenation + packed = torch.cat(self.data, dim=0) + else: + # Need to account for padding values + from netharn.data.collate import padded_collate + inbatch = list(ub.flatten(self.data)) + packed = padded_collate(inbatch, fill_value=self.padding_value) + return packed + + def to(self, device): + """ move data onto a device and return self (Tensor.to is not inplace, so we reassign) """ + for i, item in enumerate(self.data): + if torch.is_tensor(item): + self.data[i] = item.to(device) + else: + self.data[i] = [ + subitem.to(device) for subitem in item] + return self + class ItemContainer(ub.NiceRepr): """ @@ -137,14 +185,24 @@ class ItemContainer(ub.NiceRepr): } def __nice__(self): - shape_repr = ub.repr2(nestshape(self.data), nl=-2) - return 'nestshape(data)={}'.format(shape_repr) + try: + shape_repr = ub.repr2(nestshape(self.data), nl=-2) + return 'nestshape(data)={}'.format(shape_repr) + except Exception: + return super().__repr__() # return 'nestshape(data)={}, **{}'.format(shape_repr, ub.repr2(self.meta, nl=0)) @classmethod def demo(cls, key='img', rng=None, **kwargs): """ Create data for tests + + Example: + >>> from netharn.data.data_containers import * # NOQA + >>> print(ItemContainer.demo('img')) + >>> print(ItemContainer.demo('labels')) + >>> print(ItemContainer.demo('box')) + """ import kwarray rng = kwarray.ensure_rng(rng) @@ -158,6 +216,11 @@ class ItemContainer(ub.NiceRepr): data = rng.randint(0, 10, n) data = torch.from_numpy(data) self = cls(data, stack=False) + elif key == 'box': + n = rng.randint(0, 10) + data = rng.rand(n, 4) + data = torch.from_numpy(data) + self = cls(data, stack=False) else: raise KeyError(key) return self @@ -219,11 +282,11 @@ class ItemContainer(ub.NiceRepr): >>> print('Collate Image ItemContainer') >>> inbatch = [ItemContainer.demo('img') for _ in range(5)] >>> print('inbatch = {}'.format(ub.repr2(inbatch))) - >>> result = ItemContainer._collate(inbatch, 2) + >>> result = ItemContainer._collate(inbatch, num_devices=2) >>> print('result1 = {}'.format(ub.repr2(result, nl=1))) - >>> result = ItemContainer._collate(inbatch, 1) + >>> result = ItemContainer._collate(inbatch, num_devices=1) >>> print('result2 = {}'.format(ub.repr2(result, nl=1))) - >>> result = ItemContainer._collate(inbatch, None) + >>> result = ItemContainer._collate(inbatch, num_devices=None) >>> print('resultN = {}'.format(ub.repr2(result, nl=1))) >>> print('Collate Label ItemContainer') @@ -722,15 +785,12 @@ def container_gather(outputs, target_device, dim=0): # xdev.embed() return OrigGather.apply(target_device, dim, *outputs_) if isinstance(out, BatchContainer): - # if out.datatype is list: newdata = [d for dc in outputs_ for d in dc.data] if not out.cpu_only: import netharn as nh target_xpu = nh.XPU(target_device) newdata = target_xpu.move(newdata) return newdata - # else: - # raise NotImplementedError(repr(out.datatype)) if out is None: return None if isinstance(out, dict): @@ -788,6 +848,25 @@ class ContainerXPU(XPU): model = DataSerial(model) return model + def move(xpu, data, **kwargs): + try: + if xpu.is_gpu(): + return data.to(xpu._main_device_id, **kwargs) + else: + return data.to('cpu') + except AttributeError: + # Recursive move + if isinstance(data, container_abcs.Mapping): + cls = data.__class__ + return cls((k, xpu.move(v)) for k, v in data.items()) + elif isinstance(data, (container_abcs.Sequence, 
container_abcs.Set)): + cls = data.__class__ + return cls(xpu.move(v) for v in data) + elif isinstance(data, BatchContainer): + return data.to(xpu._main_device_id, **kwargs) + else: + raise TypeError('Unknown type {}'.format(type(data))) + def nestshape(data): import ubelt as ub @@ -844,3 +923,12 @@ def _debug_inbatch_shapes(inbatch): return ub.repr2(dict(type=str(type(data)), shape=data.shape), nl=1, sv=1) print('inbatch = ' + ub.repr2(inbatch, extensions=extensions, nl=True)) + + +if __name__ == '__main__': + """ + CommandLine: + xdoctest netharn.data.data_containers all + """ + import xdoctest + xdoctest.doctest_module(__file__) diff --git a/netharn/fit_harn.py b/netharn/fit_harn.py index b76aea5e8db1ad279e1942d061bb51fbd8884d86..e2937e4978cdf9b948a34e251c1b4dd0a8b7630e 100644 --- a/netharn/fit_harn.py +++ b/netharn/fit_harn.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -""" +r""" Notes: when profiling ensure CUDA_LAUNCH_BLOCKING=1 @@ -88,6 +88,7 @@ Example: >>> # non-algorithmic behavior configs (do not change learned models) >>> harn.preferences['use_tensorboard'] = False >>> harn.preferences['timeout'] = 0.5 + >>> # harn.preferences['colored'] = False >>> # start training. >>> harn.initialize(reset='delete') >>> harn.run() # note: run calls initialize it hasn't already been called. @@ -788,6 +789,8 @@ class ProgMixin(object): desc = 'epoch lr:{} | {}'.format(lr_str, harn.monitor.message()) else: desc = 'epoch lr:{} │ {}'.format(lr_str, harn.monitor.message()) + if not harn.preferences['colored']: + desc = strip_ansi(desc) harn.debug(desc) harn.main_prog.set_description(desc, refresh=False) if isinstance(harn.main_prog, ub.ProgIter): @@ -844,6 +847,8 @@ class LogMixin(object): Args: msg (str): an info message to log """ + if not harn.preferences['colored']: + msg = strip_ansi(msg) harn._ensure_prog_newline() if harn._log: try: @@ -865,6 +870,8 @@ class LogMixin(object): msg = strip_ansi(msg) harn._log.error(msg) else: + if not harn.preferences['colored']: + msg = strip_ansi(msg) print(msg) def warn(harn, msg): @@ -879,6 +886,8 @@ class LogMixin(object): msg = strip_ansi(msg) harn._log.warning(msg) else: + if not harn.preferences['colored']: + msg = strip_ansi(msg) print(msg) def debug(harn, msg): @@ -1952,6 +1961,8 @@ class CoreMixin(object): msg = harn._batch_msg({'loss': ave_metrics['loss']}, bsize, learn) + if not harn.preferences['colored']: + desc = strip_ansi(desc) prog.set_description(tag + ' ' + msg, refresh=False) # log_iter_train, log_iter_test, log_iter_vali @@ -2790,6 +2801,10 @@ class FitHarnPreferences(scfg.Config): # Deprecated 'use_tqdm': scfg.Value(None, help='deprecated'), + + 'colored': scfg.Value(True, help=( + 'allow for ANSI colored text in stdout logs, ' + 'otherwise it is stripped')), } diff --git a/netharn/initializers/_nx_extensions.py b/netharn/initializers/_nx_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..e364ef8a1ebe92ed05326037f9d796df9344947f --- /dev/null +++ b/netharn/initializers/_nx_extensions.py @@ -0,0 +1,1004 @@ +""" +EXPERIMENTAL : NEW WORK ON THIS IS HAPPENING IN NETWORKX ITSELF + +ONCE THAT IS DONE I WILL MODIFY THE ALGORITHMS HERE. 
+""" + +import operator +import ubelt as ub +import networkx as nx + +try: + import xdev + profile = xdev.profile +except Exception: + profile = ub.identity + + +# Cython gives a 40x speed boost in the nx version but not here +TRY_USE_CYTHON = 0 + + +@profile +def maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity='auto'): + """ + Finds the maximum common subtree-embedding between two ordered trees. + + A tree S is an embedded subtree of T if it can be obtained from T by a + series of edge contractions. + + Note this produces a subtree embedding, which is not necessarilly a + subgraph isomorphism (although a subgraph isomorphism is also an + embedding.) + + The maximum common embedded subtree problem can be solved in in + `O(n1 * n2 * min(d1, l1) * min(d2, l2))` time on ordered trees with n1 and + n2 nodes, of depth d1 and d2 and with l1 and l2 leaves, respectively + + Implements algorithm described in [1]_. + + References: + On the Maximum Common Embedded Subtree Problem for Ordered Trees + https://pdfs.semanticscholar.org/0b6e/061af02353f7d9b887f9a378be70be64d165.pdf + + http://algo.inria.fr/flajolet/Publications/FlSiSt90.pdf + + Notes: + Exact algorithms for computing the tree edit distance between unordered trees - https://pdf.sciencedirectassets.com/271538/1-s2.0-S0304397510X00299/1-s2.0-S0304397510005463/main.pdf ? + + Tree Edit Distance and Common Subtrees - https://upcommons.upc.edu/bitstream/handle/2117/97554/R02-20.pdf + + A Survey on Tree Edit Distance and Related Problems - https://grfia.dlsi.ua.es/ml/algorithms/references/editsurvey_bille.pdf + + Args: + + tree1 (nx.OrderedDiGraph): first ordered tree + tree2 (nx.OrderedDiGraph): second ordered tree + node_affinity (callable): function + + Example: + >>> from netharn.initializers._nx_extensions import * # NOQA + >>> from netharn.initializers._nx_extensions import _lcs, _print_forest + >>> def random_ordered_tree(n, seed=None): + >>> tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) + >>> otree = nx.OrderedDiGraph() + >>> otree.add_edges_from(tree.edges) + >>> return otree + >>> tree1 = random_ordered_tree(10, seed=1) + >>> tree2 = random_ordered_tree(10, seed=2) + >>> print('tree1') + >>> _print_forest(tree1) + >>> print('tree2') + >>> _print_forest(tree2) + + >>> embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree1, tree2 ) + >>> print('embedding1') + >>> _print_forest(embedding1) + >>> print('embedding2') + >>> _print_forest(embedding2) + """ + if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree1)): + raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') + if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree2)): + raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') + + # Convert the trees to balanced sequences + sequence1, open_to_close, toks = tree_to_balanced_sequence(tree1, open_to_close=None, toks=None) + sequence2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks) + seq1 = sequence1 + seq2 = sequence2 + + open_to_tok = ub.invert_dict(toks) + + # Solve the longest common balanced sequence problem + best, value = longest_common_balanced_sequence( + seq1, seq2, open_to_close, open_to_tok=open_to_tok, node_affinity=node_affinity) + subseq1, subseq2 = best + + # Convert the subsequence back into a tree + embedding1 = seq_to_tree(subseq1, open_to_close, toks) + embedding2 = seq_to_tree(subseq2, open_to_close, toks) + return embedding1, embedding2 + + +@profile +def 
maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity='auto'): + """ + Isomorphic version of `maximum_common_ordered_tree_embedding`. + + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/_nx_extensions.py maximum_common_ordered_subtree_isomorphism:1 --profile && cat profile_output.txt + + Ignore: + >>> from netharn.initializers._nx_extensions import * # NOQA + >>> from netharn.initializers._nx_extensions import _lcs, _print_forest + >>> def random_ordered_tree(n, seed=None): + >>> tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) + >>> otree = nx.OrderedDiGraph() + >>> otree.add_edges_from(tree.edges) + >>> return otree + >>> tree1 = random_ordered_tree(10, seed=3) + >>> tree2 = random_ordered_tree(10, seed=2) + >>> tree1.add_edges_from(tree2.edges, weight=1) + >>> tree1 = nx.minimum_spanning_arborescence(tree1) + >>> tree2.add_edges_from(tree1.edges, weight=1) + >>> tree2 = nx.minimum_spanning_arborescence(tree2) + >>> tree1.remove_edge(4, 7) + >>> tree1.remove_edge(4, 9) + >>> tree1.add_edge(4, 10) + >>> tree1.add_edge(10, 7) + >>> tree1.add_edge(10, 9) + >>> #tree1.add_edges_from([(9, 11), (11, 12), (12, 13), (13, 14)]) + >>> #tree2.add_edges_from([(9, 11), (11, 12), (12, 13), (13, 14)]) + >>> tree1.add_edges_from([(9, 11), (11, 12)]) + >>> tree2.add_edges_from([(9, 11), (11, 12)]) + >>> tree2.add_edge(100, 0) + >>> tree1.add_edge(102, 100) + >>> tree1.add_edge(100, 101) + >>> tree1.add_edge(101, 0) + >>> tree1.add_edge(5, 201) + >>> tree1.add_edge(5, 202) + >>> tree1.add_edge(5, 203) + >>> tree1.add_edge(201, 2000) + >>> tree1.add_edge(2000, 2001) + >>> tree1.add_edge(2001, 2002) + >>> tree1.add_edge(2002, 2003) + >>> tree2.add_edge(5, 202) + >>> tree2.add_edge(5, 203) + >>> tree2.add_edge(5, 201) + >>> tree2.add_edge(201, 2000) + >>> tree2.add_edge(2000, 2001) + >>> tree2.add_edge(2001, 2002) + >>> tree2.add_edge(2002, 2003) + >>> print('-----') + >>> print('tree1') + >>> _print_forest(tree1) + >>> print('tree2') + >>> _print_forest(tree2) + >>> subtree1, subtree2 = maximum_common_ordered_subtree_isomorphism(tree1, tree2 ) + >>> print('-----') + >>> print('subtree1') + >>> _print_forest(subtree1) + >>> print('subtree2') + >>> _print_forest(subtree2) + >>> embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree1, tree2) + >>> print('-----') + >>> print('embedding1') + >>> _print_forest(embedding1) + >>> print('embedding2') + >>> _print_forest(embedding2) + >>> if 0: + >>> ti = timerit.Timerit(6, bestof=2, verbose=2) + >>> for timer in ti.reset('isomorphism'): + >>> with timer: + >>> maximum_common_ordered_subtree_isomorphism(tree1, tree2 ) + >>> for timer in ti.reset('embedding'): + >>> with timer: + >>> maximum_common_ordered_tree_embedding(tree1, tree2 ) + >>> from networkx import isomorphism + >>> assert isomorphism.DiGraphMatcher(tree1, subtree1).subgraph_is_isomorphic() + >>> assert isomorphism.DiGraphMatcher(tree2, subtree2).subgraph_is_isomorphic() + >>> list(isomorphism.DiGraphMatcher(tree1, tree2).subgraph_isomorphisms_iter()) + >>> list(isomorphism.DiGraphMatcher(tree1, tree2).subgraph_monomorphisms_iter()) + >>> list(isomorphism.DiGraphMatcher(subtree1, subtree2).subgraph_isomorphisms_iter()) + >>> list(isomorphism.DiGraphMatcher(tree1, subtree1).subgraph_isomorphisms_iter()) + >>> list(isomorphism.DiGraphMatcher(tree2, subtree2).subgraph_isomorphisms_iter()) + + Ignore: + >>> from netharn.initializers._nx_extensions import * # NOQA + >>> from netharn.initializers._nx_extensions import _lcs, _print_forest + >>> def 
random_ordered_tree(n, seed=None): + >>> if n > 0: + >>> tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) + >>> otree = nx.OrderedDiGraph() + >>> if n > 0: + >>> otree.add_edges_from(tree.edges) + >>> return otree + >>> import random + >>> rng = random.Random(90269698983701724775426457020022) + >>> num = 1000 + >>> def _gen_seeds(num): + >>> for _ in range(num): + >>> yield (rng.randint(0, 50), rng.randint(0, 50), rng.randint(0, 2 ** 64), rng.randint(0, 2 ** 64)) + >>> for n1, n2, s1, s2 in ub.ProgIter(_gen_seeds(num=num), total=num, verbose=3): + >>> tree1 = random_ordered_tree(n1, seed=s1) + >>> tree2 = random_ordered_tree(n2, seed=s2) + >>> #print('-----') + >>> #print('tree1') + >>> #_print_forest(tree1) + >>> #print('tree2') + >>> #_print_forest(tree2) + >>> subtree1, subtree2 = maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity='auto') + >>> #print('-----') + >>> #print('subtree1') + >>> #_print_forest(subtree1) + >>> #print('subtree2') + >>> #_print_forest(subtree2) + >>> from networkx import isomorphism + >>> assert isomorphism.DiGraphMatcher(tree1, subtree1).subgraph_is_isomorphic() + >>> assert isomorphism.DiGraphMatcher(tree2, subtree2).subgraph_is_isomorphic() + + """ + try: + if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree1)): + raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') + if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree2)): + raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') + except nx.NetworkXPointlessConcept: + subtree1 = nx.OrderedDiGraph() + subtree2 = nx.OrderedDiGraph() + return subtree1, subtree2 + + # Convert the trees to balanced sequences + sequence1, open_to_close, toks = tree_to_balanced_sequence(tree1, open_to_close=None, toks=None, mode='chr') + sequence2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks, mode='chr') + seq1 = sequence1 + seq2 = sequence2 + + open_to_tok = ub.invert_dict(toks) + + # Solve the longest common balanced sequence problem + best, value = longest_common_isomorphic_sequence( + seq1, seq2, open_to_close, open_to_tok=open_to_tok, node_affinity=node_affinity) + subseq1, subseq2 = best + + # Convert the subsequence back into a tree + subtree1 = seq_to_tree(subseq1, open_to_close, toks) + subtree2 = seq_to_tree(subseq2, open_to_close, toks) + return subtree1, subtree2 + + +class UnbalancedException(Exception): + pass + + +def tree_to_balanced_sequence(tree, open_to_close=None, toks=None, mode='tuple'): + from collections import namedtuple + Token = namedtuple('Token', ['action', 'value']) + # mapping between opening and closing tokens + sources = [n for n in tree.nodes if tree.in_degree[n] == 0] + sequence = [] + + if open_to_close is None: + open_to_close = {} + if toks is None: + toks = {} + + for source in sources: + for u, v, etype in nx.dfs_labeled_edges(tree, source=source): + if etype == 'forward': + # u has been visited by v has not + if v not in toks: + if mode == 'tuple': + # TODO: token encoding scheme where subdirectories + # are matchable via a custom operation. 
+ # open_tok = '<{}>'.format(v) + # close_tok = ''.format(v) + open_tok = Token('open', v) + close_tok = Token('close', v) + elif mode == 'number': + open_tok = len(toks) + 1 + close_tok = -open_tok + elif mode == 'paren': + open_tok = '{}('.format(v) + close_tok = '){}'.format(v) + elif mode == 'chr': + open_tok = str(v) + close_tok = str(v) + u'\u0301' + # chr(ord(v) + 128) + toks[v] = open_tok + open_to_close[open_tok] = close_tok + open_tok = toks[v] + sequence.append(open_tok) + elif etype == 'reverse': + # Both u and v are visited and the edge is in the tree + close_tok = open_to_close[toks[v]] + sequence.append(close_tok) + else: + raise KeyError(etype) + sequence = tuple(sequence) + return sequence, open_to_close, toks + + +def seq_to_tree(subseq, open_to_close, toks): + open_to_tok = ub.invert_dict(toks) + subtree = nx.OrderedDiGraph() + stack = [] + for token in subseq: + if token in open_to_close: + node = open_to_tok[token] + if stack: + parent = open_to_tok[stack[-1]] + subtree.add_edge(parent, node) + else: + subtree.add_node(node) + stack.append(token) + else: + if not stack: + raise Exception + prev_open = stack.pop() + want_close = open_to_close[prev_open] + if token != want_close: + raise Exception + return subtree + + +def random_ordered_tree(n, seed=None): + tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) + otree = nx.OrderedDiGraph() + otree.add_edges_from(tree.edges) + return otree + + +@profile +def generate_balance_unsafe_python(sequence, open_to_close): + """ + Benchmark: + >>> tree = random_ordered_tree(1000) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='tuple') + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='number') + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2) + >>> for timer in ti.reset('time'): + >>> with timer: + >>> list(generate_balance_unsafe(sequence, open_to_close)) + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2) + >>> for timer in ti.reset('time'): + >>> with timer: + >>> list(generate_balance_unsafe_cython(sequence, open_to_close)) + """ + stacklen = 0 + for token in sequence: + if token in open_to_close: + stacklen += 1 + else: + stacklen -= 1 + yield stacklen == 0, token + + +@profile +def balanced_decomp(sequence, open_to_close): + """ + Note this is not exactly the same as the decomposition in the paper. + That is because we also return the "wrapping" element, and we let the + user do the head + tail concatenation. 
+ + Example: + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1, 0, 1] + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> sequence = '({[[]]})[[][]]' + >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) + >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) + """ + gen = generate_balance(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if tok_curr is None: + break + elif bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + # if __debug__: + # list(gen) # exhaust the generator to check we are balanced + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + + +@profile +def balanced_decomp_unsafe(sequence, open_to_close): + """ + open_to_close = {0: 1} + sequence = [0, 0, 0, 1, 1, 1, 0, 1] + open_to_close = {'{': '}', '(': ')', '[': ']'} + sequence = '({[[]]})[[][]]' + a1, b1, head, tail = balanced_decomp(sequence, open_to_close) + a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) + + Benchmark: + >>> from netharn.initializers._nx_extensions import * # NOQA + >>> tree = random_ordered_tree(100) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2, unit='us') + >>> for timer in ti.reset('safe-python'): + >>> with timer: + >>> list(balanced_decomp(sequence, open_to_close)) + >>> for timer in ti.reset('unsafe-python'): + >>> with timer: + >>> list(balanced_decomp_unsafe(sequence, open_to_close)) + >>> for timer in ti.reset('unsafe-python-v2'): + >>> with timer: + >>> list(balanced_decomp_unsafe2_python(sequence, open_to_close)) + >>> for timer in ti.reset('unsafe-c/python-v2'): + >>> with timer: + >>> list(balanced_decomp_unsafe2(sequence, open_to_close)) + """ + gen = generate_balance_unsafe(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + + +@profile +def balanced_decomp_unsafe2_python(sequence, open_to_close): + stacklen = 0 + seq_iter = iter(sequence) + tok_curr = next(seq_iter) + stacklen += 1 if tok_curr in open_to_close else -1 + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, tok_curr in enumerate(seq_iter, start=1): + stacklen += 1 if tok_curr in open_to_close else -1 + if stacklen == 0 and tok_curr == want_close: + break + + pop_close = sequence[head_stop:head_stop + 1] + pop_open = sequence[0:1] + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + + +generate_balance_unsafe = generate_balance_unsafe_python +balanced_decomp_unsafe2 = balanced_decomp_unsafe2_python + + +if TRY_USE_CYTHON: + try: + from netharn.initializers import _nx_extensions_cython_backend as cyb + + generate_balance_unsafe_cython = cyb.generate_balance_unsafe_cython + generate_balance_unsafe = cyb.generate_balance_unsafe_cython + + balanced_decomp_unsafe2_cython = cyb.balanced_decomp_unsafe2_cython + balanced_decomp_unsafe2 = cyb.balanced_decomp_unsafe2_cython + except Exception: 
+ pass + + +def generate_balance(sequence, open_to_close, safe=True): + """ + Args: + safe (bool): if True we will error if the sequence is not balanced + if you are SURE the sequence is balanced set safe=False to slightly + improve runtime. + + + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/_nx_extensions.py generate_balance:1 --profile + + Example: + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1] + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + + Example: + >>> tree = random_ordered_tree(1000) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + + Benchmark: + >>> from netharn.initializers._nx_extensions import * # NOQA + >>> tree = random_ordered_tree(100) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2, unit='us') + >>> for timer in ti.reset('safe-python'): + >>> with timer: + >>> list(generate_balance(sequence, open_to_close)) + >>> for timer in ti.reset('unsafe-python'): + >>> with timer: + >>> list(generate_balance_unsafe(sequence, open_to_close)) + + Ignore: + from netharn.initializers._nx_extensions import * # NOQA + from numba import jit + jit_generate_balance = jit(forceobj=True)(generate_balance) + + open_to_close = {0: 1} + sequence = [0, 0, 0, 1, 1, 1] + list(jit_generate_balance(sequence, open_to_close)) + + tree = random_ordered_tree(1000) + sequence, open_to_close, toks = tree_to_balanced_sequence(tree) + + import timerit + ti = timerit.Timerit(100, bestof=10, verbose=2, unit='us') + + for timer in ti.reset('safe-python'): + with timer: + list(generate_balance(sequence, open_to_close)) + + for timer in ti.reset('unsafe-python'): + with timer: + list(generate_balance_unsafe(sequence, open_to_close)) + + for timer in ti.reset('numba'): + with timer: + list(jit_generate_balance(sequence, open_to_close)) + """ + if safe: + stack = [] + # Traversing the Expression + for token in sequence: + + if token in open_to_close: + # Push opening elements onto the stack + stack.append(token) + else: + # Check that closing elements + if not stack: + raise UnbalancedException + prev_open = stack.pop() + want_close = open_to_close[prev_open] + + if token != want_close: + raise UnbalancedException + + # If the stack is empty the sequence is currently balanced + currently_balanced = not bool(stack) + yield currently_balanced, token + + if stack: + raise UnbalancedException + else: + yield from generate_balance_unsafe(sequence, open_to_close) + + +@profile +def longest_common_balanced_sequence(seq1, seq2, open_to_close, node_affinity='auto', open_to_tok=None): + """ + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/_nx_extensions.py longest_common_balanced_sequence:0 --profile && cat profile_output.txt + + Example: + >>> tree1 = random_ordered_tree(100, seed=1) + >>> tree2 = random_ordered_tree(100, seed=2) + >>> seq1, open_to_close, toks = tree_to_balanced_sequence(tree1) + >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks) + >>> longest_common_balanced_sequence(seq1, seq2, open_to_close) + + Benchmark: + >>> tree1 = random_ordered_tree(20, seed=1) + >>> tree2 = random_ordered_tree(20, seed=2) + >>> seq1, open_to_close, toks = 
tree_to_balanced_sequence(tree1) + >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks) + >>> longest_common_balanced_sequence(seq1, seq2, open_to_close) + + import sys, ubelt + sys.path.append(ubelt.expandpath('~/code/netharn')) + from netharn.initializers._nx_extensions import * # NOQA + from netharn.initializers._nx_extensions import _best_prefix_transform, _lcs, _print_forest + + open_to_close = {'0': '1'} + seq1 = '0010010010111100001011011011' + seq2 = '001000101101110001000100101110111011' + + open_to_close = {'(': ')'} + seq1 = '(()(()(()())))(((()())())())' + seq2 = '(()((()())()))((()((()(()()))()))())' + longest_common_balanced_sequence(seq1, seq2, open_to_close) + + open_to_close = {'0': '1'} + seq1 = '0010010010111100001011011011' + seq2 = '001000101101110001000100101110111011' + longest_common_balanced_sequence(seq1, seq2, open_to_close) + + open_to_close = {'0': '1'} + seq1 = '001101' + seq2 = '00110011' + seq1 = '001101' + seq2 = '00110011' + longest_common_balanced_sequence(seq1, seq2, open_to_close) + + open_to_close = {'{': '}', '(': ')', '[': ']'} + seq1 = '(({}{([])}[{}]))' + seq2 = '((({}[{{}}])))' + + seq1 = '({[[[]]]}){}' + seq2 = '{}{[[[]]]}' + best, value = longest_common_balanced_sequence(seq1, seq2, open_to_close) + subseq1, subseq2 = best + print('subseq1 = {!r}'.format(subseq1)) + """ + if node_affinity == 'auto': + node_affinity = operator.eq + if node_affinity is None: + def _matchany(a, b): + return True + node_affinity = _matchany + _memo = {} + _seq_memo = {} + if open_to_tok is None: + class Dummy: + def __getitem__(self, key): + return key + open_to_tok = Dummy() + best, value = _lcs(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + return best, value + + +@profile +def _lcs(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): + if not seq1: + return (seq1, seq1), 0 + elif not seq2: + return (seq2, seq2), 0 + else: + # if len(seq2) < len(seq1): + # seq1, seq2 = seq2, seq1 + # key = (seq1, seq2) + key1 = hash(seq1) # using hash(seq) is faster than seq itself + key2 = hash(seq2) + key = hash((key1, key2)) + if key in _memo: + return _memo[key] + + # TODO: we can probably just do a single linear run through the + # sequences to index the sub-sequence locations and then apply an + # offset when we run the decomposed sequence. 
+ if key1 in _seq_memo: + a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] + else: + a1, b1, head1, tail1 = balanced_decomp_unsafe2(seq1, open_to_close) + head1_tail1 = head1 + tail1 + _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 + + if key2 in _seq_memo: + a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] + else: + a2, b2, head2, tail2 = balanced_decomp_unsafe2(seq2, open_to_close) + head2_tail2 = head2 + tail2 + _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 + + # Case 2: The current edge in sequence1 is deleted + best, val = _lcs(head1_tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + # Case 3: The current edge in sequence2 is deleted + cand, val_alt = _lcs(seq1, head2_tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + if val_alt > val: + best = cand + val = val_alt + + # Case 1: The LCS involves this edge + t1 = open_to_tok[a1[0]] + t2 = open_to_tok[a2[0]] + # if node_affinity(a1[0], a2[0]): + affinity = node_affinity(t1, t2) + if affinity: + new_heads, pval_h = _lcs(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + new_tails, pval_t = _lcs(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + cand = (subseq1, subseq2) + val_alt = pval_h + pval_t + affinity + if val_alt > val: + best = cand + val = val_alt + + found = (best, val) + _memo[key] = found + return found + + +@profile +def longest_common_isomorphic_sequence(seq1, seq2, open_to_close, node_affinity='auto', open_to_tok=None): + if node_affinity == 'auto': + node_affinity = operator.eq + if node_affinity is None: + def _matchany(a, b): + return True + node_affinity = _matchany + _memo = {} + _seq_memo = {} + if open_to_tok is None: + class Dummy: + def __getitem__(self, key): + return key + open_to_tok = Dummy() + best_lvl, value_lvl, best_low, value_low = _lcsi(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + if value_lvl > value_low: + best = best_lvl + value = value_lvl + else: + best = best_low + value = value_low + + return best, value + + +@profile +def _lcsi(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): + """ + Prototype isomorphic only version + """ + if not seq1: + return (seq1, seq1), 0, (seq1, seq1), 0 + elif not seq2: + return (seq2, seq2), 0, (seq2, seq2), 0 + else: + key1 = hash(seq1) + key2 = hash(seq2) + key = hash((key1, key2)) + if key in _memo: + return _memo[key] + + if key1 in _seq_memo: + a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] + else: + a1, b1, head1, tail1 = balanced_decomp_unsafe2(seq1, open_to_close) + head1_tail1 = head1 + tail1 + _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 + + if key2 in _seq_memo: + a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] + else: + a2, b2, head2, tail2 = balanced_decomp_unsafe2(seq2, open_to_close) + head2_tail2 = head2 + tail2 + _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 + + # TODO: IS THIS THE CORRECT MODIFICATION TO THE RECURRANCE TO + # ACHIEVE A SUBTREE ISOMORPHISM INSTEAD OF AN EMBEDDING? 
+ r""" + + tree1 = nx.OrderedDiGraph() + tree1.add_nodes_from(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + tree1.add_edges_from([('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'e'), ('b', 'f'), ('c', 'g')]) + + _print_forest(tree1) + + └── a + ├── b + │   ├── e + │   └── f + ├── c + │   └── g + └── d + + seq1, open_to_close, toks = tree_to_balanced_sequence(tree1, mode='chr') + a, b, head1, tail1 = balanced_decomp(seq1, open_to_close) + _print_forest(seq_to_tree(head1, open_to_close, toks)) + _print_forest(seq_to_tree(tail1, open_to_close, toks)) + + CONTRACTED NODE: + a + + HEAD (children of the contracted node) + + ├── b + │   ├── e + │   └── f + ├── c + │   └── g + └── d + + TAIL (right siblings of the contracted node) + -- + + a, b, head11, tail11 = balanced_decomp(head1, open_to_close) + _print_forest(seq_to_tree(head11, open_to_close, toks)) + _print_forest(seq_to_tree(tail11, open_to_close, toks)) + + CONTRACTED NODE: + b + + HEAD OF HEAD + ├── e + └── f + + TAIL OF HEAD + ├── c + │   └── g + └── d + + + The problem here is that if you are at a level where two levels down + there are two matches, you will return those two matches as the best + solution at that layer, and therefore you won't flag if there is a + feasible solution at this layer. This is a problem because that + feasible low-value solution might be part of the highest value + solution. + + Perhaps we return two solutions at each step: the solution value at + this level if one exists, and the solution value at any other depth. + We are allowed to add to the first, but can take the second if we want + to. + + This should work because we know a solution that skipped a layer will + never be added to, and we are always keeping track of the solution that + might change. By the time we get to the root level, we have enough info + to know which is better. + """ + + # If any of these cases are selected we are not choosing the leftmost + # node as our match + best_lvl, val_lvl, best_low, val_low = None, -1, None, -1 + + # TODO: it may be the case that some of these tests are redundant, in + # which case we could simplify and speed up the algorithm. We would + # need to prove that the value in one of these tests was always lower + # than the value in another one of these tests, in that case we could + # remove the former. 
+ + # When using the head part of the decomp, we can only update the "low" candidate + cand_lvl, score_lvl, cand_low, score_low = _lcsi(head1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + if score_low > val_low: + val_low = score_low + best_low = cand_low + if score_lvl > val_low: + val_low = score_lvl + best_low = cand_lvl + + cand_lvl, score_lvl, cand_low, score_low = _lcsi(seq1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + if score_low > val_low: + val_low = score_low + best_low = cand_low + if score_lvl > val_low: + val_low = score_lvl + best_low = cand_lvl + + # As long as we are only using the tail part of the decomp we can update + # both the lvl and low scores + cand_lvl, score_lvl, cand_low, score_low = _lcsi(tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + if score_lvl > val_lvl: + val_lvl = score_lvl + best_lvl = cand_lvl + if score_low > val_low: + val_low = score_low + best_low = cand_low + + cand_lvl, score_lvl, cand_low, score_low = _lcsi(seq1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + if score_lvl > val_lvl: + val_lvl = score_lvl + best_lvl = cand_lvl + if score_low > val_low: + val_low = score_low + best_low = cand_low + + # This is the case where we found a matching node + t1 = open_to_tok[a1[0]] + t2 = open_to_tok[a2[0]] + affinity = node_affinity(t1, t2) + if affinity: + + new_heads_lvl, pval_h_lvl, new_heads_low, pval_h_low = _lcsi(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + new_tails_lvl, pval_t_lvl, new_tails_low, pval_t_low = _lcsi(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + # Add to the best solution at the former level + score_lvl = pval_h_lvl + pval_t_lvl + affinity + if score_lvl > val_lvl: + new_head1, new_head2 = new_heads_lvl + new_tail1, new_tail2 = new_tails_lvl + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + cand_lvl = (subseq1, subseq2) + val_lvl = score_lvl + best_lvl = cand_lvl + + # In my big tests these were never hit once, is it true that this + # test was covered by a previous case? + cand_low = new_heads_low + score_low = pval_h_low + if score_low > val_low: + val_low = score_low + best_low = cand_low + + cand_low = new_tails_low + score_low = pval_t_low + if score_low > val_low: + val_low = score_low + best_low = cand_low + + # We return two solutions: + # the best AT this level (lvl), and the best AT any lowers (low). 
+ found = (best_lvl, val_lvl, best_low, val_low) + _memo[key] = found + return found + + +def _print_forest(graph): + """ + Nice ascii representation of a forest + + Ignore: + graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) + _print_forest(graph) + + graph = CategoryTree.demo('coco').graph + _print_forest(graph) + """ + if len(graph.nodes) == 0: + print('--') + return + assert nx.is_forest(graph) + + def _recurse(node, indent='', islast=False): + if islast: + this_prefix = indent + '└── ' + next_prefix = indent + ' ' + else: + this_prefix = indent + '├── ' + next_prefix = indent + '│   ' + label = graph.nodes[node].get('label', node) + print(this_prefix + str(label)) + graph.succ[node] + children = graph.succ[node] + for idx, child in enumerate(children, start=1): + islast_next = (idx == len(children)) + _recurse(child, indent=next_prefix, islast=islast_next) + + sources = [n for n in graph.nodes if graph.in_degree[n] == 0] + for idx, node in enumerate(sources, start=1): + islast_next = (idx == len(sources)) + _recurse(node, indent='', islast=islast_next) + + +def maximum_common_ordered_paths(paths1, paths2, sep='/'): + import networkx as nx + + # the longest common balanced sequence problem + def _affinity(tok1, tok2): + score = 0 + for t1, t2 in zip(tok1[::-1], tok2[::-1]): + if t1 == t2: + score += 1 + else: + break + return score + # return tok1[-1] == tok2[-1] + node_affinity = _affinity + # import operator + # eq = operator.eq + + def paths_to_tree(paths): + tree = nx.OrderedDiGraph() + for path in sorted(paths): + parts = tuple(path.split(sep)) + node_path = [] + for i in range(1, len(parts) + 1): + node = parts[0:i] + tree.add_node(node) + tree.nodes[node]['label'] = node[-1] + node_path.append(node) + for u, v in ub.iter_window(node_path, 2): + tree.add_edge(u, v) + return tree + + tree1 = paths_to_tree(paths1) + tree2 = paths_to_tree(paths2) + + subtree1, subtree2 = maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity=node_affinity) + # subtree1, subtree2 = maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity=node_affinity) + + subpaths1 = [sep.join(node) for node in subtree1.nodes if subtree1.out_degree[node] == 0] + subpaths2 = [sep.join(node) for node in subtree2.nodes if subtree2.out_degree[node] == 0] + return subpaths1, subpaths2 diff --git a/netharn/initializers/_nx_extensions_cython_backend.pyx b/netharn/initializers/_nx_extensions_cython_backend.pyx new file mode 100644 index 0000000000000000000000000000000000000000..c3d312e6146f61a51417bfd838ae8235a066ba61 --- /dev/null +++ b/netharn/initializers/_nx_extensions_cython_backend.pyx @@ -0,0 +1,46 @@ +""" +cythonize -a -i ~/code/netharn/netharn/initializers/_nx_extensions_cython_backend.pyx + + >>> from netharn.initializers import _nx_extensions_cython_backend + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2) + >>> for timer in ti.reset('time'): + >>> with timer: + >>> list(_nx_extensions_cython_backend.generate_balance_unsafe_cython(sequence, open_to_close)) + +""" + +def generate_balance_unsafe_cython(sequence, open_to_close): + cdef tuple item + cdef bint flag + cdef int stacklen = 0 + for token in sequence: + if token in open_to_close: + stacklen += 1 + else: + stacklen -= 1 + flag = stacklen == 0 + item = (flag, token) + yield item + + +def balanced_decomp_unsafe2_cython(tuple sequence, dict open_to_close): + cdef int stacklen = 1 # always +1 in the first iteration + cdef int head_stop = 1 + + tok_curr = sequence[0] + want_close = 
open_to_close[tok_curr] + + # for tok_curr in sequence[1:]: + for head_stop in range(1, len(sequence)): + tok_curr = sequence[head_stop] + stacklen += 1 if tok_curr in open_to_close else -1 + if stacklen == 0 and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + + pop_open = sequence[0:1] + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + diff --git a/netharn/initializers/balanced_sequence.py b/netharn/initializers/balanced_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..13d5db1751c06bea874db00cac374abbb9849b6a --- /dev/null +++ b/netharn/initializers/balanced_sequence.py @@ -0,0 +1,969 @@ +import operator +import ubelt as ub +import networkx as nx + +try: + import xdev + profile = xdev.profile +except Exception: + profile = ub.identity + + +# @profile +def longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok=None, node_affinity='auto', impl='iter'): + """ + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/balanced_sequence.py longest_common_balanced_sequence:0 --profile && cat profile_output.txt + + Example: + >>> from netharn.initializers.balanced_sequence import * # NOQA + >>> from netharn.initializers.balanced_sequence import _lcs_iter_prehash, _lcs_iter_simple, _lcs_recurse, _print_forest + >>> tree1 = random_ordered_tree(5, seed=10, pool='[{(') + >>> tree2 = random_ordered_tree(5, seed=3, pool='[{(') + + >>> import kwarray + >>> rng = kwarray.ensure_rng(3432432, 'python') + >>> tree1 = random_ordered_tree(100, seed=rng, pool='[{(') + >>> tree2 = random_ordered_tree(100, seed=rng, pool='[{(') + >>> if len(tree1.nodes) < 20: + >>> _print_forest(tree1) + >>> _print_forest(tree2) + >>> seq1, open_to_close, toks = tree_to_balanced_sequence(tree1, mode='label', strhack=1) + >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks, mode='label', strhack=1) + >>> full_seq1 = seq1 + >>> full_seq2 = seq2 + >>> print('seq1 = {!r}'.format(seq1)) + >>> print('seq2 = {!r}'.format(seq2)) + >>> open_to_tok = ub.invert_dict(toks) + >>> node_affinity = operator.eq + >>> with ub.Timer('iterative-alt2'): + >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter-alt2') + >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) + >>> with ub.Timer('iterative-alt1'): + >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter-alt1') + >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) + >>> with ub.Timer('iterative'): + >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter') + >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) + >>> with ub.Timer('recursive'): + >>> best2, val2 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='recurse') + >>> print('val2, best2 = {}, {!r}'.format(val2, best2)) + >>> #with ub.Timer('iterative-prehash'): + >>> # best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter-prehash') + >>> # print('val1, best1 = {}, {!r}'.format(val1, best1)) + """ + if node_affinity == 'auto' or node_affinity == 'eq': + node_affinity = operator.eq + if node_affinity is None: + def _matchany(a, b): + return True + node_affinity = _matchany + _memo = {} + _seq_memo = {} + if open_to_tok is None: + class Dummy: + def __getitem__(self, key): + return key + open_to_tok = Dummy() + full_seq1 = 
seq1 + full_seq2 = seq2 + if impl == 'recurse': + best, value = _lcs_recurse(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + elif impl == 'iter': + best, value = _lcs_iter_simple(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) + elif impl == 'iter-prehash': + best, value = _lcs_iter_prehash(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) + elif impl == 'iter-alt1': + best, value = _lcs_iter_simple_alt1(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) + elif impl == 'iter-alt2': + best, value = _lcs_iter_simple_alt2(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) + else: + raise KeyError(impl) + return best, value + + +@profile +def _lcs_iter_simple(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): + """ + Converts _lcs_recursive to an iterative algorithm using a fairly + straightforward method that effectivly simulates callstacks + """ + all_decomp1 = generate_all_decompositions(full_seq1, open_to_close, open_to_tok) + all_decomp2 = generate_all_decompositions(full_seq2, open_to_close, open_to_tok) + + args0 = (full_seq1, full_seq2) + frame0 = args0 + stack = [frame0] + + _results = {} + # Populate base cases + empty1 = type(ub.peek(all_decomp1.keys()))() + empty2 = type(ub.peek(all_decomp2.keys()))() + best = (empty1, empty2) + base_result = (0, best) + for seq1 in all_decomp1.keys(): + key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result + + del args0 + del frame0 + del empty1 + del empty2 + del best + del base_result + + missing_frames = [] + while stack: + key = stack.pop() + if key not in _results: + seq1, seq2 = key + missing_frames.clear() + + # try: + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + # except KeyError: + # a1, b1, head1, tail1 = balanced_decomp_unsafe(seq1, open_to_close) + # head_tail1 = head1 + tail1 + # all_decomp1[seq1] = a1, b1, head1, tail1, head_tail1 + + # try: + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + # except KeyError: + # a2, b2, head2, tail2 = balanced_decomp_unsafe(seq2, open_to_close) + # head_tail2 = head2 + tail2 + # all_decomp2[seq2] = a2, b2, head2, tail2, head_tail2 + + # Case 2: The current edge in sequence1 is deleted + try: + try_key = (head_tail1, seq2) + cand1 = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + # Case 3: The current edge in sequence2 is deleted + try: + try_key = (seq1, head_tail2) + cand2 = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try: + try_key = (head1, head2) + pval_h, new_heads = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + try: + try_key = (tail1, tail2) + pval_t, new_tails = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + if not missing_frames: + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = 
(subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + if missing_frames: + # We did not solve this frame yet + stack.append(key) + stack.extend(missing_frames) + # stack.extend(missing_frames[::-1]) + else: + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + + val, best = _results[key] + found = (best, val) + return found + + +@profile +def _lcs_iter_simple_alt1(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): + """ + Depth first stack trajectory + """ + all_decomp1 = generate_all_decompositions(full_seq1, open_to_close, open_to_tok) + all_decomp2 = generate_all_decompositions(full_seq2, open_to_close, open_to_tok) + + args0 = (full_seq1, full_seq2) + frame0 = args0 + stack = [frame0] + + _results = {} + # Populate base cases + empty1 = type(ub.peek(all_decomp1.keys()))() + empty2 = type(ub.peek(all_decomp2.keys()))() + best = (empty1, empty2) + base_result = (0, best) + for seq1 in all_decomp1.keys(): + key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result + + del args0 + del frame0 + del empty1 + del empty2 + del best + del base_result + + while stack: + key = stack.pop() + if key not in _results: + seq1, seq2 = key + + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + + # Case 2: The current edge in sequence1 is deleted + try: + try_key = (head_tail1, seq2) + cand1 = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + # Case 3: The current edge in sequence2 is deleted + try: + try_key = (seq1, head_tail2) + cand2 = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try: + try_key = (head1, head2) + pval_h, new_heads = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + try: + try_key = (tail1, tail2) + pval_t, new_tails = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + + val, best = _results[key] + found = (best, val) + return found + + +@profile +def _lcs_iter_simple_alt2(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): + """ + Depth first stack trajectory and replace try except statements with ifs + """ + all_decomp1 = generate_all_decompositions(full_seq1, open_to_close, open_to_tok) + all_decomp2 = generate_all_decompositions(full_seq2, open_to_close, open_to_tok) + + key0 = (full_seq1, full_seq2) + frame0 = key0 + stack = [frame0] + + _results = {} + # Populate base cases + empty1 = type(ub.peek(all_decomp1.keys()))() + empty2 = 
type(ub.peek(all_decomp2.keys()))() + best = (empty1, empty2) + base_result = (0, best) + for seq1 in all_decomp1.keys(): + key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result + + del frame0 + del empty1 + del empty2 + del best + del base_result + + while stack: + key = stack[-1] + if key not in _results: + seq1, seq2 = key + + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + + # Case 2: The current edge in sequence1 is deleted + try_key = (head_tail1, seq2) + if try_key in _results: + cand1 = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 3: The current edge in sequence2 is deleted + try_key = (seq1, head_tail2) + if try_key in _results: + cand2 = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try_key = (head1, head2) + if try_key in _results: + pval_h, new_heads = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + try_key = (tail1, tail2) + if try_key in _results: + pval_t, new_tails = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + stack.pop() + + val, best = _results[key0] + found = (best, val) + return found + + +@profile +def _lcs_iter_prehash(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): + """ + Version of the lcs iterative algorithm where we precompute hash values + + This is actually slower than the simple version + """ + def decomp_info(seq, open_to_close): + pop_open, pop_close, head, tail = balanced_decomp_unsafe(seq, open_to_close) + head_tail = head + tail + head_key = hash(head) + tail_key = hash(tail) + head_tail_key = hash(head_tail) + tok = open_to_tok[pop_open[0]] + a = pop_open + b = pop_close + info = (tok, seq, head, tail, head_tail, head_key, tail_key, head_tail_key, a, b) + return info + + def gen_decomp_v2(seq, open_to_close): + _genmemo = {} + def _gen(seq): + if seq: + key = hash(seq) + if key not in _genmemo: + info = decomp_info(seq, open_to_close) + head, tail, head_tail = info[2:5] + _genmemo[key] = info + yield (seq, _genmemo[key]) + yield from _gen(head_tail) + yield from _gen(head) + yield from _gen(tail) + all_decomp = dict(_gen(seq)) + return all_decomp + + all_decomp1 = gen_decomp_v2(full_seq1, open_to_close) + all_decomp2 = gen_decomp_v2(full_seq2, open_to_close) + + key_decomp1 = {} + key_decomp2 = {} + _results = {} + # Populate base cases + empty1 = type(ub.peek(all_decomp1.keys()))() + empty2 = type(ub.peek(all_decomp2.keys()))() + empty1_key = hash(empty1) + empty2_key = hash(empty2) + best = (empty1, empty2) + 
base_result = (0, best) + for seq1, info1 in all_decomp1.items(): + seq1_key = hash(seq1) + head1_key, tail1_key, head_tail1_key = all_decomp1[seq1][5:8] + _results[(seq1_key, empty2_key)] = base_result + _results[(head1_key, empty2_key)] = base_result + _results[(tail1_key, empty2_key)] = base_result + _results[(head_tail1_key, empty2_key)] = base_result + key_decomp1[seq1_key] = info1 + + for seq2, info2 in all_decomp2.items(): + seq2_key = hash(seq2) + head2_key, tail2_key, head_tail2_key = all_decomp2[seq2][5:8] + _results[(empty1_key, seq2_key)] = base_result + _results[(empty1_key, head2_key)] = base_result + _results[(empty1_key, tail2_key)] = base_result + _results[(empty1_key, head_tail2_key)] = base_result + key_decomp2[seq2_key] = info2 + + full_seq1_key = hash(full_seq1) + full_seq2_key = hash(full_seq2) + key0 = (full_seq1_key, full_seq2_key) + frame0 = key0, full_seq1, full_seq2 + stack = [frame0] + missing_frames = [] + while stack: + frame = stack.pop() + key, seq1, seq2 = frame + seq1_key, seq2_key = key + if key not in _results: + missing_frames.clear() + + try: + info1 = key_decomp1[seq1_key] + except KeyError: + info1 = decomp_info(seq1, open_to_close) + key_decomp1[seq1_key] = info1 + tok1, seq1, head1, tail1, head_tail1, head1_key, tail1_key, head_tail1_key, a1, b1 = info1 + + try: + info2 = key_decomp2[seq2_key] + except KeyError: + info2 = decomp_info(seq2, open_to_close) + key_decomp2[seq2_key] = info2 + tok2, seq2, head2, tail2, head_tail2, head2_key, tail2_key, head_tail2_key, a2, b2 = info2 + + affinity = node_affinity(tok1, tok2) + + # Case 2: The current edge in sequence1 is deleted + try: + try_key = (head_tail1_key, seq2_key) + cand1 = _results[try_key] + except KeyError: + miss_frame = try_key, head_tail1, seq2 + missing_frames.append(miss_frame) + + # Case 3: The current edge in sequence2 is deleted + try: + try_key = (seq1_key, head_tail2_key) + cand2 = _results[try_key] + except KeyError: + miss_frame = try_key, seq1, head_tail2 + missing_frames.append(miss_frame) + + # Case 1: The LCS involves this edge + if affinity: + try: + try_key = (head1_key, head2_key) + pval_h, new_heads = _results[try_key] + except KeyError: + miss_frame = try_key, head1, head2 + missing_frames.append(miss_frame) + + try: + try_key = (tail1_key, tail2_key) + pval_t, new_tails = _results[try_key] + except KeyError: + miss_frame = try_key, tail1, tail2 + missing_frames.append(miss_frame) + + if not missing_frames: + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + if missing_frames: + # We did not solve this frame yet + stack.append(frame) + stack.extend(missing_frames[::-1]) + else: + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + + # The stack pop is our solution + (val, best) = _results[key] + found = (best, val) + return found + + +def generate_all_decompositions(seq, open_to_close, open_to_tok=None): + """ + Can doing this a-priori speed up the algorithm? 
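+
+    The return value maps every reachable (sub)sequence to its decomposition
+    tuple ``(tok, pop_open, pop_close, head, tail, head_tail)``, which lets
+    the iterative LCS variants look decompositions up instead of recomputing
+    them. A small hand-worked sketch of that structure:
+
+    >>> open_to_close = {'{': '}', '(': ')', '[': ']'}
+    >>> all_decomp = generate_all_decompositions('({})[]', open_to_close)
+    >>> assert set(all_decomp) == {'({})[]', '{}[]', '{}', '[]'}
+    >>> tok, a, b, head, tail, head_tail = all_decomp['({})[]']
+    >>> assert (a, b, head, tail, head_tail) == ('(', ')', '{}', '[]', '{}[]')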
+ + open_to_close = {0: 1} + sequence = [0, 0, 0, 1, 1, 1, 0, 1] + open_to_close = {'{': '}', '(': ')', '[': ']'} + seq = '({[[]]})[[][]]{{}}' + pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) + + >>> tree = random_ordered_tree(10) + >>> seq, open_to_close, toks = tree_to_balanced_sequence(tree) + >>> all_decomp = generate_all_decompositions(seq, open_to_close) + """ + if open_to_tok is None: + class Dummy: + def __getitem__(self, key): + return key + open_to_tok = Dummy() + _memo = {} + def _gen(seq): + if not seq: + pass + # yield None + elif seq in _memo: + pass + # yield (seq, _memo[seq]) + else: + pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) + head_tail = head + tail + tok = open_to_tok[pop_open[0]] + _memo[seq] = (tok, pop_open, pop_close, head, tail, head_tail) + yield (seq, _memo[seq]) + yield from _gen(head_tail) + yield from _gen(head) + yield from _gen(tail) + all_decomp = dict(_gen(seq)) + return all_decomp + + +@profile +def _lcs_recurse(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): + if not seq1: + return (seq1, seq1), 0 + elif not seq2: + return (seq2, seq2), 0 + else: + # if len(seq2) < len(seq1): + # seq1, seq2 = seq2, seq1 + # key = (seq1, seq2) + key1 = hash(seq1) # using hash(seq) is faster than seq itself + key2 = hash(seq2) + key = hash((key1, key2)) + if key in _memo: + return _memo[key] + + # TODO: we can probably just do a single linear run through the + # sequences to index the sub-sequence locations and then apply an + # offset when we run the decomposed sequence. + if key1 in _seq_memo: + a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] + else: + a1, b1, head1, tail1 = balanced_decomp_unsafe(seq1, open_to_close) + head1_tail1 = head1 + tail1 + _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 + + if key2 in _seq_memo: + a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] + else: + a2, b2, head2, tail2 = balanced_decomp_unsafe(seq2, open_to_close) + head2_tail2 = head2 + tail2 + _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 + + # Case 2: The current edge in sequence1 is deleted + best, val = _lcs_recurse(head1_tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + # Case 3: The current edge in sequence2 is deleted + cand, val_alt = _lcs_recurse(seq1, head2_tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + if val_alt > val: + best = cand + val = val_alt + + # Case 1: The LCS involves this edge + t1 = open_to_tok[a1[0]] + t2 = open_to_tok[a2[0]] + # if node_affinity(a1[0], a2[0]): + affinity = node_affinity(t1, t2) + if affinity: + new_heads, pval_h = _lcs_recurse(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + new_tails, pval_t = _lcs_recurse(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + cand = (subseq1, subseq2) + val_alt = pval_h + pval_t + affinity + if val_alt > val: + best = cand + val = val_alt + + found = (best, val) + _memo[key] = found + return found + + +class UnbalancedException(Exception): + pass + + +def balanced_decomp(sequence, open_to_close): + """ + Note this is not exactly the same as the decomposition in the paper. + That is because we also return the "wrapping" element, and we let the + user do the head + tail concatenation. 
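+
+    Concretely, for the bracket string used in the example below, the outer
+    parenthesis pair is returned as the wrapping element and the remainder
+    splits into a head and a tail (expected values worked out by hand):
+
+    Example:
+        >>> open_to_close = {'{': '}', '(': ')', '[': ']'}
+        >>> sequence = '({[[]]})[[][]]'
+        >>> pop_open, pop_close, head, tail = balanced_decomp(sequence, open_to_close)
+        >>> assert (pop_open, pop_close) == ('(', ')')
+        >>> assert (head, tail) == ('{[[]]}', '[[][]]')
+        >>> # concatenating the pieces reconstructs the input
+        >>> assert pop_open + head + pop_close + tail == sequence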
+ + Example: + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1, 0, 1] + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> sequence = '({[[]]})[[][]]' + >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) + >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) + """ + gen = generate_balance(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if tok_curr is None: + break + elif bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + # if __debug__: + # list(gen) # exhaust the generator to check we are balanced + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + + +def tree_to_balanced_sequence(tree, open_to_close=None, toks=None, mode='tuple', strhack=False): + from collections import namedtuple + Token = namedtuple('Token', ['action', 'value']) + # mapping between opening and closing tokens + sources = [n for n in tree.nodes if tree.in_degree[n] == 0] + sequence = [] + + if open_to_close is None: + open_to_close = {} + if toks is None: + toks = {} + + if strhack: + if mode == 'label': + all_labels = {n['label'] for n in list(tree.nodes.values())} + assert all(x == 1 for x in map(len, all_labels)) + + for source in sources: + for u, v, etype in nx.dfs_labeled_edges(tree, source=source): + if etype == 'forward': + # u has been visited by v has not + if v not in toks: + if mode == 'tuple': + # TODO: token encoding scheme where subdirectories + # are matchable via a custom operation. + # open_tok = '<{}>'.format(v) + # close_tok = ''.format(v) + open_tok = Token('open', v) + close_tok = Token('close', v) + elif mode == 'number': + open_tok = len(toks) + 1 + close_tok = -open_tok + elif mode == 'paren': + open_tok = '{}('.format(v) + close_tok = '){}'.format(v) + elif mode == 'chr': + open_tok = str(v) + close_tok = str(v) + u'\u0301' + elif mode == 'label': + open_tok = tree.nodes[v]['label'] + assert strhack + if open_tok == '{': + close_tok = '}' + if open_tok == '[': + close_tok = ']' + if open_tok == '(': + close_tok = ')' + toks[v] = open_tok + open_to_close[open_tok] = close_tok + open_tok = toks[v] + sequence.append(open_tok) + elif etype == 'reverse': + # Both u and v are visited and the edge is in the tree + close_tok = open_to_close[toks[v]] + sequence.append(close_tok) + else: + raise KeyError(etype) + sequence = tuple(sequence) + if strhack: + sequence = ''.join(sequence) + return sequence, open_to_close, toks + + +def seq_to_tree(subseq, open_to_close, toks): + open_to_tok = ub.invert_dict(toks) + subtree = nx.OrderedDiGraph() + stack = [] + for token in subseq: + if token in open_to_close: + node = open_to_tok[token] + if stack: + parent = open_to_tok[stack[-1]] + subtree.add_edge(parent, node) + else: + subtree.add_node(node) + stack.append(token) + else: + if not stack: + raise Exception + prev_open = stack.pop() + want_close = open_to_close[prev_open] + if token != want_close: + raise Exception + return subtree + + +def random_ordered_tree(n, seed=None, pool=None): + import kwarray + rng = kwarray.ensure_rng(seed, 'python') + tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) + otree = nx.OrderedDiGraph() + otree.add_edges_from(tree.edges) + if pool is not None: + for node in otree.nodes: + otree.nodes[node]['label'] = rng.choice(pool) + return otree + + +def 
generate_balance_unsafe(sequence, open_to_close): + """ + Benchmark: + >>> tree = random_ordered_tree(1000) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='tuple') + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='number') + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2) + >>> for timer in ti.reset('time'): + >>> with timer: + >>> list(generate_balance_unsafe(sequence, open_to_close)) + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2) + >>> for timer in ti.reset('time'): + >>> with timer: + >>> list(generate_balance_unsafe_cython(sequence, open_to_close)) + """ + stacklen = 0 + for token in sequence: + if token in open_to_close: + stacklen += 1 + else: + stacklen -= 1 + yield stacklen == 0, token + + +def balanced_decomp_unsafe(sequence, open_to_close): + """ + Example: + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> sequence = '({[[]]})[[][]]' + >>> print('sequence = {!r}'.format(sequence)) + >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) + >>> print('a1 = {!r}'.format(a1)) + >>> print('tail = {!r}'.format(tail)) + >>> print('head = {!r}'.format(head)) + >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) + >>> print('a2 = {!r}'.format(a2)) + >>> print('tail1 = {!r}'.format(tail1)) + >>> print('tail2 = {!r}'.format(tail2)) + """ + gen = generate_balance_unsafe(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + + +def generate_balance(sequence, open_to_close): + """ + Safe version + + Example: + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1] + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + + Example: + >>> tree = random_ordered_tree(1000) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + """ + stack = [] + # Traversing the Expression + for token in sequence: + + if token in open_to_close: + # Push opening elements onto the stack + stack.append(token) + else: + # Check that closing elements + if not stack: + raise UnbalancedException + prev_open = stack.pop() + want_close = open_to_close[prev_open] + + if token != want_close: + raise UnbalancedException + + # If the stack is empty the sequence is currently balanced + currently_balanced = not bool(stack) + yield currently_balanced, token + + if stack: + raise UnbalancedException + + +def _print_forest(graph): + """ + Nice ascii representation of a forest + + Ignore: + graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) + _print_forest(graph) + + graph = CategoryTree.demo('coco').graph + _print_forest(graph) + """ + if len(graph.nodes) == 0: + print('--') + return + assert nx.is_forest(graph) + + def _recurse(node, indent='', islast=False): + if islast: + this_prefix = indent + '└── ' + next_prefix = indent + ' ' + else: + this_prefix = indent + '├── ' + next_prefix = indent + '│   ' + label = graph.nodes[node].get('label', node) + print(this_prefix + 
str(label)) + graph.succ[node] + children = graph.succ[node] + for idx, child in enumerate(children, start=1): + islast_next = (idx == len(children)) + _recurse(child, indent=next_prefix, islast=islast_next) + + sources = [n for n in graph.nodes if graph.in_degree[n] == 0] + for idx, node in enumerate(sources, start=1): + islast_next = (idx == len(sources)) + _recurse(node, indent='', islast=islast_next) + + +__notes_ = """ + + # if 0: + # tuples = [(i + 1, i + 2, i + 3,) for i in range(4)] + # import timerit + + # ti = timerit.Timerit(100, bestof=10, verbose=2) + # import itertools as it + # for timer in ti.reset('time'): + # with timer: + # tuple(it.chain.from_iterable(tuples)) + # for timer in ti.reset('time'): + # with timer: + # res = tuples[0] + # for a in tuples[1:]: + # res = res + a + +""" diff --git a/netharn/initializers/bseq2.py b/netharn/initializers/bseq2.py new file mode 100644 index 0000000000000000000000000000000000000000..1b26842d0c1f2cb1f6cabb7e290a861613bea51a --- /dev/null +++ b/netharn/initializers/bseq2.py @@ -0,0 +1,612 @@ +import operator +import ubelt as ub +import networkx as nx + +try: + import xdev + profile = xdev.profile +except Exception: + profile = ub.identity + + +def longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok=None, node_affinity='auto', impl='iter'): + """ + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/balanced_sequence.py longest_common_balanced_sequence:0 --profile && cat profile_output.txt + + Example: + >>> from netharn.initializers.balanced_sequence import * # NOQA + >>> tree1 = random_ordered_tree(5, seed=10, pool='[{(') + >>> tree2 = random_ordered_tree(5, seed=3, pool='[{(') + + >>> import kwarray + >>> rng = kwarray.ensure_rng(None, 'python') + >>> tree1 = random_ordered_tree(100, seed=rng, pool='[{(') + >>> tree2 = random_ordered_tree(200, seed=rng, pool='[{(') + >>> if len(tree1.nodes) < 20: + >>> _print_forest(tree1) + >>> _print_forest(tree2) + >>> seq1, open_to_close, toks = tree_to_balanced_sequence(tree1, mode='label', strhack=1) + >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks, mode='label', strhack=1) + >>> full_seq1 = seq1 + >>> full_seq2 = seq2 + >>> print('seq1 = {!r}'.format(seq1)) + >>> print('seq2 = {!r}'.format(seq2)) + >>> open_to_tok = ub.invert_dict(toks) + >>> with ub.Timer('recursive'): + >>> best2, val2 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='recurse') + >>> print('val2, best2 = {}, {!r}'.format(val2, best2)) + >>> with ub.Timer('iterative'): + >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter') + >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) + """ + if node_affinity == 'auto' or node_affinity == 'eq': + node_affinity = operator.eq + if node_affinity is None: + def _matchany(a, b): + return True + node_affinity = _matchany + _memo = {} + _seq_memo = {} + if open_to_tok is None: + class Dummy: + def __getitem__(self, key): + return key + open_to_tok = Dummy() + full_seq1 = seq1 + full_seq2 = seq2 + if impl == 'recurse': + best, value = _lcs_recurse(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + elif impl == 'iter': + best, value = _lcs_iter(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) + else: + raise KeyError(impl) + return best, value + + +@profile +def _lcs_iter(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): + def decomp_info(seq, open_to_close): + 
pop_open, pop_close, head, tail = balanced_decomp_unsafe(seq, open_to_close) + head_tail = head + tail + head_key = hash(head) + tail_key = hash(tail) + head_tail_key = hash(head_tail) + tok = open_to_tok[pop_open[0]] + a = pop_open + b = pop_close + info = (tok, seq, head, tail, head_tail, head_key, tail_key, head_tail_key, a, b) + return info + + def gen_decomp_v2(seq, open_to_close): + _genmemo = {} + def _gen(seq): + if seq: + key = hash(seq) + if key not in _genmemo: + info = decomp_info(seq, open_to_close) + head, tail, head_tail = info[2:5] + _genmemo[key] = info + yield (seq, _genmemo[key]) + yield from _gen(head_tail) + yield from _gen(head) + yield from _gen(tail) + all_decomp = dict(_gen(seq)) + return all_decomp + + all_decomp1 = gen_decomp_v2(full_seq1, open_to_close) + all_decomp2 = gen_decomp_v2(full_seq2, open_to_close) + + key_decomp1 = {} + key_decomp2 = {} + _results = {} + # Populate base cases + empty1 = type(ub.peek(all_decomp1.keys()))() + empty2 = type(ub.peek(all_decomp2.keys()))() + empty1_key = hash(empty1) + empty2_key = hash(empty2) + best = (empty1, empty2) + base_result = (0, best) + for seq1, info1 in all_decomp1.items(): + seq1_key = hash(seq1) + head1_key, tail1_key, head_tail1_key = all_decomp1[seq1][5:8] + _results[(seq1_key, empty2_key)] = base_result + _results[(head1_key, empty2_key)] = base_result + _results[(tail1_key, empty2_key)] = base_result + _results[(head_tail1_key, empty2_key)] = base_result + key_decomp1[seq1_key] = info1 + + for seq2, info2 in all_decomp2.items(): + seq2_key = hash(seq2) + head2_key, tail2_key, head_tail2_key = all_decomp2[seq2][5:8] + _results[(empty1_key, seq2_key)] = base_result + _results[(empty1_key, head2_key)] = base_result + _results[(empty1_key, tail2_key)] = base_result + _results[(empty1_key, head_tail2_key)] = base_result + key_decomp2[seq2_key] = info2 + + full_seq1_key = hash(full_seq1) + full_seq2_key = hash(full_seq2) + key0 = (full_seq1_key, full_seq2_key) + frame0 = key0, full_seq1, full_seq2 + stack = [frame0] + missing_frames = [] + num_misses = 0 + while stack: + frame = stack.pop() + key, seq1, seq2 = frame + seq1_key, seq2_key = key + if key not in _results: + missing_frames.clear() + + try: + info1 = key_decomp1[seq1_key] + except KeyError: + info1 = decomp_info(seq1, open_to_close) + key_decomp1[seq1_key] = info1 + tok1, seq1, head1, tail1, head_tail1, head1_key, tail1_key, head_tail1_key, a1, b1 = info1 + + try: + info2 = key_decomp2[seq2_key] + except KeyError: + info2 = decomp_info(seq2, open_to_close) + key_decomp2[seq2_key] = info2 + tok2, seq2, head2, tail2, head_tail2, head2_key, tail2_key, head_tail2_key, a2, b2 = info2 + + affinity = node_affinity(tok1, tok2) + + # Case 2: The current edge in sequence1 is deleted + try: + try_key = (head_tail1_key, seq2_key) + cand1 = _results[try_key] + except KeyError: + miss_frame = try_key, head_tail1, seq2 + missing_frames.append(miss_frame) + + # Case 3: The current edge in sequence2 is deleted + try: + try_key = (seq1_key, head_tail2_key) + cand2 = _results[try_key] + except KeyError: + miss_frame = try_key, seq1, head_tail2 + missing_frames.append(miss_frame) + + # Case 1: The LCS involves this edge + if affinity: + try: + try_key = (head1_key, head2_key) + pval_h, new_heads = _results[try_key] + except KeyError: + miss_frame = try_key, head1, head2 + missing_frames.append(miss_frame) + + try: + try_key = (tail1_key, tail2_key) + pval_t, new_tails = _results[try_key] + except KeyError: + miss_frame = try_key, tail1, tail2 + 
missing_frames.append(miss_frame) + + if not missing_frames: + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + if missing_frames: + num_misses += 1 + # We did not solve this frame yet + stack.append(frame) + stack.extend(missing_frames[::-1]) + else: + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + + print('num_misses = {!r}'.format(num_misses)) + + # The stack pop is our solution + (val, best) = _results[key] + found = (best, val) + return found + + +def generate_all_decompositions(seq, open_to_close): + """ + Can doing this a-priori speed up the algorithm? + + open_to_close = {0: 1} + sequence = [0, 0, 0, 1, 1, 1, 0, 1] + open_to_close = {'{': '}', '(': ')', '[': ']'} + seq = '({[[]]})[[][]]{{}}' + pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) + + >>> tree = random_ordered_tree(1000) + >>> seq, open_to_close, toks = tree_to_balanced_sequence(tree) + >>> all_decomp = _generate_all_decompositions(seq, open_to_close) + """ + _memo = {} + def _gen(seq): + if not seq: + pass + # yield None + elif seq in _memo: + pass + # yield (seq, _memo[seq]) + else: + pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) + head_tail = head + tail + _memo[seq] = (pop_open, pop_close, head, tail, head_tail) + yield (seq, _memo[seq]) + yield from _gen(head_tail) + yield from _gen(head) + yield from _gen(tail) + all_decomp = dict(_gen(seq)) + return all_decomp + + +@profile +def _lcs_recurse(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): + if not seq1: + return (seq1, seq1), 0 + elif not seq2: + return (seq2, seq2), 0 + else: + # if len(seq2) < len(seq1): + # seq1, seq2 = seq2, seq1 + # key = (seq1, seq2) + key1 = hash(seq1) # using hash(seq) is faster than seq itself + key2 = hash(seq2) + key = hash((key1, key2)) + if key in _memo: + return _memo[key] + + # TODO: we can probably just do a single linear run through the + # sequences to index the sub-sequence locations and then apply an + # offset when we run the decomposed sequence. 
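+        #
+        # _seq_memo caches the (open, close, head, tail, head + tail)
+        # decomposition of each sequence so the recursion never decomposes
+        # the same subsequence twice, while _memo caches the LCS result for
+        # each (seq1, seq2) pair of subproblems.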
+ if key1 in _seq_memo: + a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] + else: + a1, b1, head1, tail1 = balanced_decomp_unsafe(seq1, open_to_close) + head1_tail1 = head1 + tail1 + _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 + + if key2 in _seq_memo: + a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] + else: + a2, b2, head2, tail2 = balanced_decomp_unsafe(seq2, open_to_close) + head2_tail2 = head2 + tail2 + _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 + + # Case 2: The current edge in sequence1 is deleted + best, val = _lcs_recurse(head1_tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + # Case 3: The current edge in sequence2 is deleted + cand, val_alt = _lcs_recurse(seq1, head2_tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + if val_alt > val: + best = cand + val = val_alt + + # Case 1: The LCS involves this edge + t1 = open_to_tok[a1[0]] + t2 = open_to_tok[a2[0]] + # if node_affinity(a1[0], a2[0]): + affinity = node_affinity(t1, t2) + if affinity: + new_heads, pval_h = _lcs_recurse(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + new_tails, pval_t = _lcs_recurse(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + cand = (subseq1, subseq2) + val_alt = pval_h + pval_t + affinity + if val_alt > val: + best = cand + val = val_alt + + found = (best, val) + _memo[key] = found + return found + + +class UnbalancedException(Exception): + pass + + +def balanced_decomp(sequence, open_to_close): + """ + Note this is not exactly the same as the decomposition in the paper. + That is because we also return the "wrapping" element, and we let the + user do the head + tail concatenation. 
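+
+    Concatenating the returned pieces reconstructs the input, i.e.
+    ``pop_open + head + pop_close + tail == sequence``; the LCS recursion
+    relies on this structure when it rebuilds candidate subsequences as
+    ``a + new_head + b + new_tail``.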
+ + Example: + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1, 0, 1] + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> sequence = '({[[]]})[[][]]' + >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) + >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) + """ + gen = generate_balance(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if tok_curr is None: + break + elif bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + # if __debug__: + # list(gen) # exhaust the generator to check we are balanced + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + + +def tree_to_balanced_sequence(tree, open_to_close=None, toks=None, mode='tuple', strhack=False): + from collections import namedtuple + Token = namedtuple('Token', ['action', 'value']) + # mapping between opening and closing tokens + sources = [n for n in tree.nodes if tree.in_degree[n] == 0] + sequence = [] + + if open_to_close is None: + open_to_close = {} + if toks is None: + toks = {} + + if strhack: + if mode == 'label': + all_labels = {n['label'] for n in list(tree.nodes.values())} + assert all(x == 1 for x in map(len, all_labels)) + + for source in sources: + for u, v, etype in nx.dfs_labeled_edges(tree, source=source): + if etype == 'forward': + # u has been visited by v has not + if v not in toks: + if mode == 'tuple': + # TODO: token encoding scheme where subdirectories + # are matchable via a custom operation. + # open_tok = '<{}>'.format(v) + # close_tok = ''.format(v) + open_tok = Token('open', v) + close_tok = Token('close', v) + elif mode == 'number': + open_tok = len(toks) + 1 + close_tok = -open_tok + elif mode == 'paren': + open_tok = '{}('.format(v) + close_tok = '){}'.format(v) + elif mode == 'chr': + open_tok = str(v) + close_tok = str(v) + u'\u0301' + elif mode == 'label': + open_tok = tree.nodes[v]['label'] + assert strhack + if open_tok == '{': + close_tok = '}' + if open_tok == '[': + close_tok = ']' + if open_tok == '(': + close_tok = ')' + toks[v] = open_tok + open_to_close[open_tok] = close_tok + open_tok = toks[v] + sequence.append(open_tok) + elif etype == 'reverse': + # Both u and v are visited and the edge is in the tree + close_tok = open_to_close[toks[v]] + sequence.append(close_tok) + else: + raise KeyError(etype) + sequence = tuple(sequence) + if strhack: + sequence = ''.join(sequence) + return sequence, open_to_close, toks + + +def seq_to_tree(subseq, open_to_close, toks): + open_to_tok = ub.invert_dict(toks) + subtree = nx.OrderedDiGraph() + stack = [] + for token in subseq: + if token in open_to_close: + node = open_to_tok[token] + if stack: + parent = open_to_tok[stack[-1]] + subtree.add_edge(parent, node) + else: + subtree.add_node(node) + stack.append(token) + else: + if not stack: + raise Exception + prev_open = stack.pop() + want_close = open_to_close[prev_open] + if token != want_close: + raise Exception + return subtree + + +def random_ordered_tree(n, seed=None, pool=None): + import kwarray + rng = kwarray.ensure_rng(seed, 'python') + tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) + otree = nx.OrderedDiGraph() + otree.add_edges_from(tree.edges) + if pool is not None: + for node in otree.nodes: + otree.nodes[node]['label'] = rng.choice(pool) + return otree + + +def 
generate_balance_unsafe(sequence, open_to_close): + """ + Benchmark: + >>> tree = random_ordered_tree(1000) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='tuple') + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='number') + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2) + >>> for timer in ti.reset('time'): + >>> with timer: + >>> list(generate_balance_unsafe(sequence, open_to_close)) + >>> import timerit + >>> ti = timerit.Timerit(100, bestof=10, verbose=2) + >>> for timer in ti.reset('time'): + >>> with timer: + >>> list(generate_balance_unsafe_cython(sequence, open_to_close)) + """ + stacklen = 0 + for token in sequence: + if token in open_to_close: + stacklen += 1 + else: + stacklen -= 1 + yield stacklen == 0, token + + +def balanced_decomp_unsafe(sequence, open_to_close): + """ + Example: + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> sequence = '({[[]]})[[][]]' + >>> print('sequence = {!r}'.format(sequence)) + >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) + >>> print('a1 = {!r}'.format(a1)) + >>> print('tail = {!r}'.format(tail)) + >>> print('head = {!r}'.format(head)) + >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) + >>> print('a2 = {!r}'.format(a2)) + >>> print('tail1 = {!r}'.format(tail1)) + >>> print('tail2 = {!r}'.format(tail2)) + """ + gen = generate_balance_unsafe(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + return pop_open, pop_close, head, tail + + +def generate_balance(sequence, open_to_close): + """ + Safe version + + Example: + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1] + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + + Example: + >>> tree = random_ordered_tree(1000) + >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + """ + stack = [] + # Traversing the Expression + for token in sequence: + + if token in open_to_close: + # Push opening elements onto the stack + stack.append(token) + else: + # Check that closing elements + if not stack: + raise UnbalancedException + prev_open = stack.pop() + want_close = open_to_close[prev_open] + + if token != want_close: + raise UnbalancedException + + # If the stack is empty the sequence is currently balanced + currently_balanced = not bool(stack) + yield currently_balanced, token + + if stack: + raise UnbalancedException + + +def _print_forest(graph): + """ + Nice ascii representation of a forest + + Ignore: + graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) + _print_forest(graph) + + graph = CategoryTree.demo('coco').graph + _print_forest(graph) + """ + if len(graph.nodes) == 0: + print('--') + return + assert nx.is_forest(graph) + + def _recurse(node, indent='', islast=False): + if islast: + this_prefix = indent + '└── ' + next_prefix = indent + ' ' + else: + this_prefix = indent + '├── ' + next_prefix = indent + '│   ' + label = graph.nodes[node].get('label', node) + print(this_prefix + 
str(label)) + graph.succ[node] + children = graph.succ[node] + for idx, child in enumerate(children, start=1): + islast_next = (idx == len(children)) + _recurse(child, indent=next_prefix, islast=islast_next) + + sources = [n for n in graph.nodes if graph.in_degree[n] == 0] + for idx, node in enumerate(sources, start=1): + islast_next = (idx == len(sources)) + _recurse(node, indent='', islast=islast_next) + + +__notes_ = """ + + # if 0: + # tuples = [(i + 1, i + 2, i + 3,) for i in range(4)] + # import timerit + + # ti = timerit.Timerit(100, bestof=10, verbose=2) + # import itertools as it + # for timer in ti.reset('time'): + # with timer: + # tuple(it.chain.from_iterable(tuples)) + # for timer in ti.reset('time'): + # with timer: + # res = tuples[0] + # for a in tuples[1:]: + # res = res + a + +""" diff --git a/netharn/initializers/functional.py b/netharn/initializers/functional.py index 1f3665fe8176c0757d1a04f5ddbcd97cbcdadb9f..1c488c509a425b91566922bd03cb5f216b1c1b6b 100644 --- a/netharn/initializers/functional.py +++ b/netharn/initializers/functional.py @@ -123,7 +123,8 @@ def apply_initializer(input, func, funckw): def load_partial_state(model, model_state_dict, leftover=None, ignore_unset=False, verbose=2, - mangle=True, initializer=None): + mangle=True, association=None, + initializer=None): """ CommandLine: python -m netharn.initializers.nninit_base load_partial_state @@ -136,6 +137,10 @@ def load_partial_state(model, model_state_dict, leftover=None, leftover (callable): fallback method for initializing incompatible areas, if none then those areas are left as-is. + association (str): controls how we search for the association between + the two model states. Can be strict, module-hack, prefix-hack, or + embedding. Default is: prefix-hack. + mangle (bool, default=True): If True, mangles tensors that have the same key, but different shapes forcing them to fit. This might destroy information when forcing a a larger tensor into a smaller @@ -151,6 +156,75 @@ def load_partial_state(model, model_state_dict, leftover=None, TODO: - [ ] Allow user to specify how incompatible layers are handled. + Notes: + + Have you ever had the scenario where + + Has anyone ever had a problem where you had a torch model with a state + dict with keys that looked like: `mymodel.detector.layer1.conv.weight`, + but you had a pretrained weight file with keys that looked like: + `module.layer1.conv.weight`? + + The latest version of + `netharn.initializers.functional.load_patial_state` can handle this by + solving a maximum-common-subtree-isomorphism problem. This computes the + largest possible mapping between the two state dictionaries that share + consistent suffixes. 
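+
+        A minimal sketch of requesting this behavior explicitly (the
+        ``'extra.'`` prefix is only an illustrative mismatch, and ``ToyNet2d``
+        is the same toy model used in the examples further below):
+
+        >>> # xdoctest: +SKIP
+        >>> import netharn as nh
+        >>> model = nh.models.ToyNet2d(input_channels=1, num_classes=10)
+        >>> weights = {'extra.' + k: v for k, v in model.state_dict().items()}
+        >>> load_partial_state(model, weights, association='embedding')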
+ + >>> # This means you can load an off-the-shelf unmodified pretrained resnet50 + >>> # where the keys might look something like this: + >>> resnet_keys = { + >>> 'conv1.weight', + >>> 'layer1.0.conv1.weight', + >>> 'layer1.0.conv2.weight', + >>> 'layer1.0.conv3.weight', + >>> 'layer1.0.downsample.0.weight', + >>> 'layer2.0.conv1.weight', + >>> 'layer2.0.conv2.weight', + >>> 'layer2.0.conv3.weight', + >>> 'layer3.0.conv1.weight', + >>> 'layer4.0.conv1.weight', + >>> 'fc.weight', + >>> 'fc.bias', + >>> } + >>> # + >>> # And perhaps you have a model that has a state dict where keys + >>> # look like this: + >>> model_keys = { + >>> 'preproc.conv1.weight' + >>> 'backbone.layer1.0.conv1.weight', + >>> 'backbone.layer1.0.conv2.weight', + >>> 'backbone.layer1.0.conv3.weight', + >>> 'backbone.layer1.0.downsample.0.weight', + >>> 'backbone.layer2.0.conv1.weight', + >>> 'backbone.layer2.0.conv2.weight', + >>> 'backbone.layer2.0.conv3.weight', + >>> 'backbone.layer3.0.conv1.weight', + >>> 'backbone.layer4.0.conv1.weight', + >>> 'head.conv1' + >>> 'head.conv2' + >>> 'head.fc.weight' + >>> 'head.fc.bias' + >>> } + >>> # + >>> # We can compute a partial mapping between them + >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(resnet_keys, model_keys) + >>> print(ub.repr2(ub.dzip(subpaths1, subpaths2))) + { + 'layer1.0.conv2.weight': 'backbone.layer1.0.conv2.weight', + 'layer1.0.conv3.weight': 'backbone.layer1.0.conv3.weight', + 'layer1.0.downsample.0.weight': 'backbone.layer1.0.downsample.0.weight', + 'layer2.0.conv1.weight': 'backbone.layer2.0.conv1.weight', + 'layer2.0.conv2.weight': 'backbone.layer2.0.conv2.weight', + 'layer2.0.conv3.weight': 'backbone.layer2.0.conv3.weight', + 'layer3.0.conv1.weight': 'backbone.layer3.0.conv1.weight', + 'layer4.0.conv1.weight': 'backbone.layer4.0.conv1.weight', + } + + Also, if the sizes of the tensor don't quite fit, they will be + mangled, i.e. "shoved-in" as best as possible. + + Example: >>> import netharn as nh >>> self1 = nh.models.ToyNet2d(input_channels=1, num_classes=10) @@ -169,7 +243,48 @@ def load_partial_state(model, model_state_dict, leftover=None, >>> self2 = xpu.mount(self1) >>> load_partial_state(self2, self1.state_dict()) >>> load_partial_state(self1, self2.state_dict()) + >>> # Add extra nonsense to state-dict + >>> extra_state_dict = {'extra.' + k: v for k, v in self1.state_dict().items()} + >>> extra_state_dict['stats'] = ub.peek(extra_state_dict.values()).clone() + >>> model = self2 + >>> model_state_dict = extra_state_dict + >>> load_partial_state(self2, extra_state_dict) + + Example: + >>> # xdoctest: +REQUIRES(--slow) + >>> from netharn.initializers.functional import * # NOQA + >>> import torchvision + >>> import torch + >>> resnet50 = torchvision.models.resnet50() + >>> class CustomModel(torch.nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.module = resnet50 + >>> self.extra = torch.nn.Linear(1, 1) + >>> model = CustomModel() + >>> model_state_dict = resnet50.state_dict() + >>> model_state_dict2 = {'prefix.' 
+ k: v for k, v in model_state_dict.items()} + >>> import ubelt as ub + >>> with ub.Timer(verbose=2, label='strict'): + >>> load_partial_state(model, model_state_dict, association='strict', verbose=0) + >>> with ub.Timer(verbose=2, label='prefix-hack'): + >>> load_partial_state(model, model_state_dict, association='prefix-hack', verbose=0) + >>> with ub.Timer(verbose=2, label='module-hack'): + >>> load_partial_state(model, model_state_dict, association='module-hack', verbose=0) + >>> with ub.Timer(verbose=2, label='embedding'): + >>> load_partial_state(model, model_state_dict, association='embedding', verbose=0) + + >>> load_partial_state(model, model_state_dict, association='prefix-hack', verbose=1) + >>> load_partial_state(model, model_state_dict, association='module-hack', verbose=1) + + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/functional.py load_partial_state:2 --slow + """ + if association is None: + association = 'module-hack' # old default + # association = 'prefix-hack' # new default + if initializer is not None: import warnings warnings.warn('initializer is deprecated use leftover') @@ -185,21 +300,56 @@ def load_partial_state(model, model_state_dict, leftover=None, """ other_keys = set(model_state_dict) self_keys = set(self_state) - - if not other_keys.intersection(self_keys): - prefix = 'module.' - def smap(f, ss): - return set(map(f, ss)) - def fix1(k): - return prefix + k - def fix2(k): - if k.startswith(prefix): - return k[len(prefix):] - if smap(fix1, other_keys).intersection(self_keys): - model_state_dict = ub.map_keys(fix1, model_state_dict) - elif smap(fix2, other_keys).intersection(self_keys): - model_state_dict = ub.map_keys(fix2, model_state_dict) - + common_keys = other_keys.intersection(self_keys) + if not common_keys: + if association == 'strict': + pass + elif association == 'module-hack': + # If there are no common keys try a hack + prefix = 'module.' 
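+                # e.g. checkpoints saved from a torch.nn.DataParallel wrapper
+                # prefix every key with 'module.'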
+ def smap(f, ss): + return set(map(f, ss)) + def fix1(k): + return prefix + k + def fix2(k): + if k.startswith(prefix): + return k[len(prefix):] + if smap(fix1, other_keys).intersection(self_keys): + model_state_dict = ub.map_keys(fix1, model_state_dict) + elif smap(fix2, other_keys).intersection(self_keys): + model_state_dict = ub.map_keys(fix2, model_state_dict) + elif association == 'prefix-hack': + import functools + def add_prefix(k, prefix): + return prefix + k + def remove_prefix(k, prefix): + if k.startswith(prefix): + return k[len(prefix):] + # set1 = other_keys + # target_set2 = self_keys + found = _best_prefix_transform(other_keys, self_keys) + if found is not None: + for action, prefix in found['transform']: + if action == 'add': + func = functools.partial(add_prefix, prefix=prefix) + elif action == 'remove': + func = functools.partial(remove_prefix, prefix=prefix) + else: + raise AssertionError + model_state_dict = ub.map_keys(func, model_state_dict) + elif association == 'embedding': + if verbose > 1: + print('Using subpath embedding assocation, may take some time') + # I believe this is the correct way to solve the problem + paths1 = sorted(other_keys) + paths2 = sorted(self_state) + subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2) + mapping = ub.dzip(subpaths1, subpaths2) + if verbose > 1: + print('mapping = {}'.format(ub.repr2(mapping, nl=1))) + model_state_dict = ub.map_keys(lambda k: mapping.get(k, k), model_state_dict) + else: + raise KeyError(association) return model_state_dict other_state = _fix_keys(model_state_dict) @@ -295,6 +445,11 @@ def load_partial_state(model, model_state_dict, leftover=None, print('Seen Keys: {}'.format(ub.repr2(seen_keys, nl=2))) print('Self Unset Keys: {}'.format(ub.repr2(self_unset_keys, nl=1))) print('Other Unused keys: {}'.format(ub.repr2(other_unused_keys, nl=1))) + print('summary:') + seen_sum = ub.map_vals(len, seen_keys) + print('Seen Num: {}'.format(ub.repr2(seen_sum, nl=2))) + print('Self Unset Num: {}'.format(ub.repr2(len(self_unset_keys), nl=1))) + print('Other Unused Num: {}'.format(ub.repr2(len(other_unused_keys), nl=1))) if leftover: if verbose > 0: print('Initializing unused keys using {}'.format(leftover)) @@ -321,3 +476,252 @@ def load_partial_state(model, model_state_dict, leftover=None, 'other_unused': other_unused_keys } return info + + +def _best_prefix_transform(set1, target_set2): + """ + Find a way to transform prefixes of items in set1 to match target_set2 + + Example: + >>> set1 = {'mod.f.0.w', + >>> 'mod.f.1.b', + >>> 'mod.f.1.n', + >>> 'mod.f.1.rm', + >>> 'mod.f.1.rv',} + >>> # + >>> target_set2 = { + >>> 'bar.foo.extra.f.1.b', + >>> 'bar.foo.extra.f.1.n', + >>> 'bar.foo.extra.f.1.w', + >>> 'bar.foo.extra.f.3.w', + >>> } + >>> _best_prefix_transform(set1, target_set2) + >>> target_set2.add('JUNK') + >>> _best_prefix_transform(set1, target_set2) + """ + + # probably an efficient way to do this with a trie + + # NOTE: In general this is a graph-isomorphism problem or a maximum common + # subgraph problem. However, we can look only at the special case of + # "maximum common subtrees". Given two directory structures (as trees) + # we find the common bits. 
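+    # For the first call in the doctest above, the best transform this search
+    # should find is to remove the 'mod.' prefix and then add
+    # 'bar.foo.extra.'.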
+ # https://perso.ensta-paris.fr/~diam/ro/online/viggo_wwwcompendium/node168.html + # We can approximate to O(log log n / log^2 n) + # Can get algorithm from maximum independent set + # https://arxiv.org/abs/1602.07210 + + # The most efficient algorithm here would be for solving + # "Maximum common labeled subtrees" + # APX-hard for unordered trees, but polytime solveable for ordered trees + # For directory structures we can induce an order, and hense obtain a + # polytime solution + # # + # On the Maximum Common Embedded Subtree Problem for Ordered Trees + # https://pdfs.semanticscholar.org/0b6e/061af02353f7d9b887f9a378be70be64d165.pdf + + from os.path import commonprefix + prefixes1 = commonprefix(list(set1)).split('.') + prefixes2 = commonprefix(list(target_set2)).split('.') + + # Remove the trailing prefixes that are the same + num_same = 0 + for i in range(1, min(len(prefixes1), len(prefixes2))): + if prefixes1[-i] == prefixes2[-i]: + num_same = i + else: + break + prefixes1 = prefixes1[:-num_same] + prefixes2 = prefixes2[:-num_same] + + ALLOW_FUZZY = 1 + if ALLOW_FUZZY and len(prefixes2) == 0: + # SUPER HACK FOR CASE WHERE THERE IS JUST ONE SPOILER ELEMENT IN THE + # TARGET SET. THE ALGORITHM NEEDS TO BE RETHOUGHT FOR THAT CASE + possible_prefixes = [k.split('.') for k in target_set2] + prefix_hist = ub.ddict(lambda: 0) + for item in possible_prefixes: + for i in range(1, len(item)): + prefix_hist[tuple(item[0:i])] += 1 + prefixes2 = ['.'.join(ub.argmax(prefix_hist))] + + def add_prefix(items, prefix): + return {prefix + k for k in items} + def remove_prefix(items, prefix): + return {k[len(prefix):] if k.startswith(prefix) else k for k in items} + + import itertools as it + found_cand = [] + for i1, i2 in it.product(range(len(prefixes1) + 1), range(len(prefixes2) + 1)): + if i1 == 0 and i2 == 0: + continue + # Very inefficient, we should be able to do better + prefix1 = '.'.join(prefixes1[:i1]) + prefix2 = '.'.join(prefixes2[:i2]) + if prefix1: + prefix1 = prefix1 + '.' + if prefix2: + prefix2 = prefix2 + '.' + + # We are allowed to remove a prefix from a set, add the other + # prefix to the set, or remove and then add. + set1_cand1 = remove_prefix(set1, prefix1) + set1_cand2 = add_prefix(set1, prefix2) + set1_cand3 = add_prefix(set1_cand1, prefix2) + + common1 = set1_cand1 & target_set2 + common2 = set1_cand2 & target_set2 + common3 = set1_cand3 & target_set2 + if common1: + found_cand.append({ + 'transform': [('remove', prefix1)], + 'value': len(common1), + }) + if common2: + found_cand.append({ + 'transform': [('add', prefix2)], + 'value': len(common2), + }) + if common3: + found_cand.append({ + 'transform': [('remove', prefix1), ('add', prefix2)], + 'value': len(common3), + }) + if len(found_cand): + found = max(found_cand, key=lambda x: x['value']) + else: + found = None + return found + + +def maximum_common_ordered_subpaths(paths1, paths2, sep='.'): + """ + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/functional.py maximum_common_ordered_subpaths:0 --profile && cat profile_output.txt + xdoctest -m /home/joncrall/code/netharn/netharn/initializers/functional.py maximum_common_ordered_subpaths:0 + + Example: + >>> import torchvision + >>> resnet50 = torchvision.models.resnet50() + >>> paths1 = sorted(resnet50.state_dict().keys())[0:100] + >>> paths2 = ['prefix.' 
+ k for k in paths1] + >>> paths2.append('extra_key') + >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2) + >>> mapping = ub.dzip(subpaths1, subpaths2) + >>> print('mapping = {}'.format(ub.repr2(mapping, nl=1))) + + Example: + >>> rng = None + >>> import kwarray + >>> rng = kwarray.ensure_rng(rng) + >>> def random_paths(rng, max_depth=10): + >>> depth = rng.randint(1, max_depth) + >>> parts = list(map(chr, rng.randint(ord('a'), ord('z'), size=depth))) + >>> path = '.'.join(parts) + >>> return path + >>> n = 50 + >>> paths1 = sorted({random_paths(rng) for _ in range(n)}) + >>> paths2 = sorted({random_paths(rng) for _ in range(n)}) + >>> paths1 = paths1 + ['a.' + k for k in paths2[0:n // 3]] + >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2) + >>> mapping = ub.dzip(subpaths1, subpaths2) + >>> print('mapping = {}'.format(ub.repr2(mapping, nl=1))) + + Example: + >>> from netharn.initializers.functional import * # NOQA + >>> paths1 = [ + >>> 'stats', + >>> 'z.mod.f.0.w', + >>> 'a.z.mod.f.0.b', + >>> 'z.mod.f.1.b', + >>> 'z.mod.f.1.n', + >>> 'z.mod.f.1.m', + >>> 'z.mod.f.1.v', + >>> 'z.mod.f.2.m', + >>> 'z.mod.z.q' + >>> ] + >>> # paths1 = ['mod'] + >>> # + >>> paths2 = [ + >>> 'stats', + >>> 'bar.f.0.w', + >>> 'bar.foo.extra.z.q', + >>> 'bar.foo.extra', + >>> 'bar.foo.extra.f.1.b', + >>> 'bar.foo.extra.f.1.n', + >>> 'bar.foo.extra.f.1.w', + >>> 'bar.foo.extra.f.3.z', # FIXME we need to handle label comparision operators + >>> # I think we allow labels to match if they have the same suffix + >>> ] + >>> sep = '.' + >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2, sep) + >>> mapping = ub.dzip(subpaths1, subpaths2) + >>> print('mapping = {}'.format(ub.repr2(mapping, nl=1))) + + + Example: + >>> sep = '.' 
+ >>> paths1 = ['a.b'] + >>> paths2 = ['a.b'] + >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2, sep) + >>> mapping = ub.dzip(subpaths1, subpaths2) + >>> print('mapping = {}'.format(ub.repr2(mapping, nl=1))) + >>> paths1 = ['c.a.b'] + >>> paths2 = ['a.b'] + >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2, sep) + >>> mapping = ub.dzip(subpaths1, subpaths2) + >>> print('mapping = {}'.format(ub.repr2(mapping, nl=1))) + >>> paths1 = ['c.a.b', 'c.a.e', 'c.a.q'] + >>> paths2 = ['a.b', 'c.e', 'c.a', 'a.q'] + >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2, sep) + >>> mapping = ub.dzip(subpaths1, subpaths2) + >>> print('mapping = {}'.format(ub.repr2(mapping, nl=1))) + """ + import networkx as nx + + # the longest common balanced sequence problem + def _affinity(tok1, tok2): + score = 0 + for t1, t2 in zip(tok1[::-1], tok2[::-1]): + if t1 == t2: + score += 1 + else: + break + return score + # return tok1[-1] == tok2[-1] + node_affinity = _affinity + # import operator + # eq = operator.eq + + def paths_to_tree(paths): + tree = nx.OrderedDiGraph() + for path in sorted(paths): + parts = tuple(path.split(sep)) + node_path = [] + for i in range(1, len(parts) + 1): + node = parts[0:i] + tree.add_node(node) + tree.nodes[node]['label'] = node[-1] + node_path.append(node) + for u, v in ub.iter_window(node_path, 2): + tree.add_edge(u, v) + return tree + + tree1 = paths_to_tree(paths1) + tree2 = paths_to_tree(paths2) + + # _print_forest(tree1) + # _print_forest(tree2) + + # if 0: + # DiGM = isomorphism.DiGraphMatcher(tree1, tree2) + # DiGM.is_isomorphic() + # list(DiGM.subgraph_isomorphisms_iter()) + + from netharn.initializers import _nx_extensions + subtree1, subtree2 = _nx_extensions.maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity=node_affinity) + # subtree1, subtree2 = _nx_extensions.maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity=node_affinity) + + subpaths1 = [sep.join(node) for node in subtree1.nodes if subtree1.out_degree[node] == 0] + subpaths2 = [sep.join(node) for node in subtree2.nodes if subtree2.out_degree[node] == 0] + return subpaths1, subpaths2 diff --git a/netharn/initializers/pretrained.py b/netharn/initializers/pretrained.py index 24f69453d218ec619d6610c5ec768b4c3e5d6d35..751b8de277ce8d91f35cebcc4e491f16159020ec 100644 --- a/netharn/initializers/pretrained.py +++ b/netharn/initializers/pretrained.py @@ -40,6 +40,9 @@ class Pretrained(api.Initializer, ub.NiceRepr): placed in a larger one. Note be careful when mangling a classification layer if class indexes are not aligned. + association (str): controls how we search for the association between + the two model states. Can be strict, module-hack, prefix-hack, or embedding. + info (dict, optional): specify explicit history info initializer (netharn.Initializer): DEPRECATED use the `leftover`. 
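A minimal usage sketch for the new keyword (the checkpoint path, model, and
state dict below are hypothetical; only the ``association`` argument and the
``Pretrained`` / ``load_partial_state`` signatures come from this patch):

    >>> import netharn as nh
    >>> # `model` is any torch.nn.Module; 'deploy.zip' stands in for a real checkpoint
    >>> init = nh.initializers.Pretrained('deploy.zip', association='prefix-hack')
    >>> init(model)
    >>> # or use the functional form directly on an already-loaded state dict
    >>> from netharn.initializers.functional import load_partial_state
    >>> info = load_partial_state(model, other_state_dict,
    >>>                           association='embedding', verbose=1)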
@@ -82,7 +85,7 @@ class Pretrained(api.Initializer, ub.NiceRepr): >>> self(model2) """ def __init__(self, fpath, leftover=None, mangle=True, info=None, - initializer=None): + initializer=None, association=None): if initializer is not None: import warnings warnings.warn('Pretrained `initializer` kwarg is deprecated ' @@ -95,6 +98,7 @@ class Pretrained(api.Initializer, ub.NiceRepr): leftover = initializer_[0](**initializer_[1]) self.leftover = leftover + self.association = association self.mangle = mangle self.info = info @@ -196,6 +200,7 @@ class Pretrained(api.Initializer, ub.NiceRepr): info = load_partial_state(raw_model, model_state_dict, leftover=self.leftover, mangle=self.mangle, + association=self.association, verbose=verbose) return info diff --git a/netharn/mixins.py b/netharn/mixins.py index 1b4b8fd244f0db82ab69004fa6491c524ca52cef..f083a0243cc7d303f859081c2affa2aeb6f2301b 100644 --- a/netharn/mixins.py +++ b/netharn/mixins.py @@ -9,6 +9,7 @@ The purpose of this file is to contain functions that might not general-purpose enough to add to FitHarn itself, but they are also common enough, where it makes no sense to write them from scratch for each new project. """ +from distutils.version import LooseVersion def _dump_monitor_tensorboard(harn, mode='epoch', special_groupers=['loss'], @@ -53,6 +54,32 @@ def _dump_monitor_tensorboard(harn, mode='epoch', special_groupers=['loss'], out_dpath = ub.ensuredir((train_dpath, 'monitor', 'tensorboard')) + # Write a script that the user can run to + if not ub.WIN32: + reviz_fpath = join(out_dpath, 'revisualize.sh') + reviz_text = ub.codeblock( + ''' + #!/bin/bash + __heredoc__ = """ + Helper script to visualize all of the results in the pkl / json files + in this directory. + """ + REVIZ_DPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + xdoctest -m netharn.mixins _dump_measures --out_dpath=$REVIZ_DPATH + ''') + with open(reviz_fpath, 'w') as file: + file.write(reviz_text) + try: + import os + import stat + orig_mode = os.stat(reviz_fpath).st_mode + new_flags = stat.S_IXGRP | stat.S_IEXEC + if (new_flags & orig_mode) != new_flags: + new_mode = orig_mode | new_flags + os.chmod(reviz_fpath, new_mode) + except Exception as ex: + print('ex = {!r}'.format(ex)) + tb_data_pickle_fpath = join(out_dpath, 'tb_data.pkl') with open(tb_data_pickle_fpath, 'wb') as file: pickle.dump(tb_data, file) @@ -180,233 +207,243 @@ def _dump_measures(tb_data, out_dpath, mode=None, smoothing=0.0, from os.path import join import numpy as np import kwplot - # kwplot.autompl() - - # TODO: Is it possible to get htop to show this process with some name that - # distinguishes it from the dataloader workers? 
- # import sys - # import multiprocessing - # if multiprocessing.current_process().name != 'MainProcess': - # if sys.platform.startswith('linux'): - # import ctypes - # libc = ctypes.cdll.LoadLibrary('libc.so.6') - # title = 'Netharn MPL Dump Measures' - # libc.prctl(len(title), title, 0, 0, 0) - - # NOTE: This cause warnings when exeucted as daemon process - # try: - # import seaborn as sbn - # sbn.set() - # except ImportError: - # pass - - valid_modes = ['epoch', 'iter'] - if mode is None: - mode = valid_modes - if ub.iterable(mode): - # Hack: Call with all modes - for mode_ in mode: - _dump_measures(tb_data, out_dpath, mode=mode_, smoothing=smoothing, - ignore_outliers=ignore_outliers) - return - else: - assert mode in valid_modes - - meta = tb_data.get('meta', {}) - nice = meta.get('nice', '?nice?') - special_groupers = meta.get('special_groupers', ['loss']) - - fig = kwplot.figure(fnum=1) - - plot_keys = [key for key in tb_data if - ('train_' + mode in key or - 'vali_' + mode in key or - 'test_' + mode in key or - mode + '_' in key)] - y01_measures = ['_acc', '_ap', '_mAP', '_auc', '_mcc', '_brier', '_mauc'] - y0_measures = ['error', 'loss'] - - keys = set(tb_data.keys()).intersection(set(plot_keys)) - - # print('mode = {!r}'.format(mode)) - # print('tb_data.keys() = {!r}'.format(tb_data.keys())) - # print('plot_keys = {!r}'.format(plot_keys)) - # print('keys = {!r}'.format(keys)) - - def smooth_curve(ydata, beta): - """ - Curve smoothing algorithm used by tensorboard - """ - import pandas as pd - alpha = 1.0 - beta - if alpha <= 0: - return ydata - ydata_smooth = pd.Series(ydata).ewm(alpha=alpha).mean().values - return ydata_smooth - - def inlier_ylim(ydatas): - """ - outlier removal used by tensorboard - """ - low, high = None, None - for ydata in ydatas: - q1 = 0.05 - q2 = 0.95 - low_, high_ = np.quantile(ydata, [q1, q2]) - - # Extrapolate how big the entire span should be based on inliers - inner_q = q2 - q1 - inner_extent = high_ - low_ - extrap_total_extent = inner_extent / inner_q - - # amount of padding to add to either side - missing_p1 = q1 - missing_p2 = 1 - q2 - frac1 = missing_p1 / (missing_p2 + missing_p1) - frac2 = missing_p2 / (missing_p2 + missing_p1) - missing_extent = extrap_total_extent - inner_extent - - pad1 = missing_extent * frac1 - pad2 = missing_extent * frac2 - - low_ = low_ - pad1 - high_ = high_ + pad2 - - low = low_ if low is None else min(low_, low) - high = high_ if high is None else max(high_, high) - return (low, high) - - # Hack values that we don't apply smoothing to - HACK_NO_SMOOTH = ['lr', 'momentum'] - - def tag_grouper(k): - # parts = ['train_epoch', 'vali_epoch', 'test_epoch'] - # parts = [p.replace('epoch', 'mode') for p in parts] - parts = [p + mode for p in ['train_', 'vali_', 'test_']] - for p in parts: - if p in k: - return p.split('_')[0] - return 'unknown' - - GROUP_LOSSES = True - GROUP_AND_INDIVIDUAL = False - INDIVIDUAL_PLOTS = True - GROUP_SPECIAL = True - - if GROUP_LOSSES: - # Group all losses in one plot for comparison - loss_keys = [k for k in keys if 'loss' in k] - tagged_losses = ub.group_items(loss_keys, tag_grouper) - tagged_losses.pop('unknown', None) - kw = {} - kw['ymin'] = 0.0 - # print('tagged_losses = {!r}'.format(tagged_losses)) - for tag, losses in tagged_losses.items(): - - min_abs_y = .01 - min_y = 0 - xydata = ub.odict() - for key in sorted(losses): - ydata = tb_data[key]['ydata'] - - if HACK_NO_SMOOTH not in key.split('_'): - ydata = smooth_curve(ydata, smoothing) - - try: - min_y = min(min_y, ydata.min()) - 
pos_ys = ydata[ydata > 0] - min_abs_y = min(min_abs_y, pos_ys.min()) - except Exception: - pass - - xydata[key] = (tb_data[key]['xdata'], ydata) - - kw['ymin'] = min_y - - if ignore_outliers: - low, kw['ymax'] = inlier_ylim([t[1] for t in xydata.values()]) - - yscales = ['symlog', 'linear'] - for yscale in yscales: - fig.clf() - ax = fig.gca() - title = nice + '\n' + tag + '_' + mode + ' losses' - kwplot.multi_plot(xydata=xydata, ylabel='loss', xlabel=mode, - yscale=yscale, title=title, fnum=1, ax=ax, - **kw) - if yscale == 'symlog': - ax.set_yscale('symlog', linthreshy=min_abs_y) - fname = '_'.join([tag, mode, 'multiloss', yscale]) + '.png' - fpath = join(out_dpath, fname) - ax.figure.savefig(fpath) + import matplotlib as mpl + + from kwplot.auto_backends import BackendContext + + with BackendContext('agg'): + # kwplot.autompl() + + # TODO: Is it possible to get htop to show this process with some name that + # distinguishes it from the dataloader workers? + # import sys + # import multiprocessing + # if multiprocessing.current_process().name != 'MainProcess': + # if sys.platform.startswith('linux'): + # import ctypes + # libc = ctypes.cdll.LoadLibrary('libc.so.6') + # title = 'Netharn MPL Dump Measures' + # libc.prctl(len(title), title, 0, 0, 0) + + # NOTE: This cause warnings when exeucted as daemon process + # try: + # import seaborn as sbn + # sbn.set() + # except ImportError: + # pass + + valid_modes = ['epoch', 'iter'] + if mode is None: + mode = valid_modes + if ub.iterable(mode): + # Hack: Call with all modes + for mode_ in mode: + _dump_measures(tb_data, out_dpath, mode=mode_, smoothing=smoothing, + ignore_outliers=ignore_outliers) + return + else: + assert mode in valid_modes - # don't dump losses individually if we dump them in a group - if not GROUP_AND_INDIVIDUAL: - keys.difference_update(set(loss_keys)) - # print('keys = {!r}'.format(keys)) + meta = tb_data.get('meta', {}) + nice = meta.get('nice', '?nice?') + special_groupers = meta.get('special_groupers', ['loss']) - if GROUP_SPECIAL: - tag_groups = ub.group_items(keys, tag_grouper) - tag_groups.pop('unknown', None) - # Group items matching these strings - kw = {} - for tag, tag_keys in tag_groups.items(): - for groupname in special_groupers: - group_keys = [k for k in tag_keys if groupname in k.split('_')] - if len(group_keys) > 1: - # Gather data for this group - xydata = ub.odict() - for key in sorted(group_keys): - ydata = tb_data[key]['ydata'] - if HACK_NO_SMOOTH not in key.split('_'): - ydata = smooth_curve(ydata, smoothing) - xydata[key] = (tb_data[key]['xdata'], ydata) + fig = kwplot.figure(fnum=1) - if ignore_outliers: - low, kw['ymax'] = inlier_ylim([t[1] for t in xydata.values()]) - - yscales = ['linear'] - for yscale in yscales: - fig.clf() - ax = fig.gca() - title = nice + '\n' + tag + '_' + mode + ' ' + groupname - kwplot.multi_plot(xydata=xydata, ylabel=groupname, xlabel=mode, - yscale=yscale, title=title, fnum=1, ax=ax, - **kw) - if yscale == 'symlog': - ax.set_yscale('symlog', linthreshy=min_abs_y) - fname = '_'.join([tag, mode, 'group-' + groupname, yscale]) + '.png' - fpath = join(out_dpath, fname) - ax.figure.savefig(fpath) + plot_keys = [key for key in tb_data if + ('train_' + mode in key or + 'vali_' + mode in key or + 'test_' + mode in key or + mode + '_' in key)] + y01_measures = [ + '_acc', '_ap', '_mAP', '_auc', '_mcc', '_brier', '_mauc', + ] + y0_measures = ['error', 'loss'] - if not GROUP_AND_INDIVIDUAL: - keys.difference_update(set(group_keys)) + keys = 
set(tb_data.keys()).intersection(set(plot_keys)) - if INDIVIDUAL_PLOTS: + # print('mode = {!r}'.format(mode)) + # print('tb_data.keys() = {!r}'.format(tb_data.keys())) + # print('plot_keys = {!r}'.format(plot_keys)) # print('keys = {!r}'.format(keys)) - for key in keys: - d = tb_data[key] - - ydata = d['ydata'] - ydata = smooth_curve(ydata, smoothing) + def smooth_curve(ydata, beta): + """ + Curve smoothing algorithm used by tensorboard + """ + import pandas as pd + alpha = 1.0 - beta + if alpha <= 0: + return ydata + ydata_smooth = pd.Series(ydata).ewm(alpha=alpha).mean().values + return ydata_smooth + + def inlier_ylim(ydatas): + """ + outlier removal used by tensorboard + """ + low, high = None, None + for ydata in ydatas: + q1 = 0.05 + q2 = 0.95 + low_, high_ = np.quantile(ydata, [q1, q2]) + + # Extrapolate how big the entire span should be based on inliers + inner_q = q2 - q1 + inner_extent = high_ - low_ + extrap_total_extent = inner_extent / inner_q + + # amount of padding to add to either side + missing_p1 = q1 + missing_p2 = 1 - q2 + frac1 = missing_p1 / (missing_p2 + missing_p1) + frac2 = missing_p2 / (missing_p2 + missing_p1) + missing_extent = extrap_total_extent - inner_extent + + pad1 = missing_extent * frac1 + pad2 = missing_extent * frac2 + + low_ = low_ - pad1 + high_ = high_ + pad2 + + low = low_ if low is None else min(low_, low) + high = high_ if high is None else max(high_, high) + return (low, high) + + # Hack values that we don't apply smoothing to + HACK_NO_SMOOTH = ['lr', 'momentum'] + + def tag_grouper(k): + # parts = ['train_epoch', 'vali_epoch', 'test_epoch'] + # parts = [p.replace('epoch', 'mode') for p in parts] + parts = [p + mode for p in ['train_', 'vali_', 'test_']] + for p in parts: + if p in k: + return p.split('_')[0] + return 'unknown' + + GROUP_LOSSES = True + GROUP_AND_INDIVIDUAL = False + INDIVIDUAL_PLOTS = True + GROUP_SPECIAL = True + + if GROUP_LOSSES: + # Group all losses in one plot for comparison + loss_keys = [k for k in keys if 'loss' in k] + tagged_losses = ub.group_items(loss_keys, tag_grouper) + tagged_losses.pop('unknown', None) kw = {} - if any(m.lower() in key.lower() for m in y01_measures): - kw['ymin'] = 0.0 - kw['ymax'] = 1.0 - elif any(m.lower() in key.lower() for m in y0_measures): - kw['ymin'] = min(0.0, ydata.min()) + kw['ymin'] = 0.0 + # print('tagged_losses = {!r}'.format(tagged_losses)) + for tag, losses in tagged_losses.items(): + + min_abs_y = .01 + min_y = 0 + xydata = ub.odict() + for key in sorted(losses): + ydata = tb_data[key]['ydata'] + + if HACK_NO_SMOOTH not in key.split('_'): + ydata = smooth_curve(ydata, smoothing) + + try: + min_y = min(min_y, ydata.min()) + pos_ys = ydata[ydata > 0] + min_abs_y = min(min_abs_y, pos_ys.min()) + except Exception: + pass + + xydata[key] = (tb_data[key]['xdata'], ydata) + + kw['ymin'] = min_y + if ignore_outliers: - low, kw['ymax'] = inlier_ylim([ydata]) - - # NOTE: this is actually pretty slow - fig.clf() - ax = fig.gca() - title = nice + '\n' + key - kwplot.multi_plot(d['xdata'], ydata, ylabel=key, xlabel=mode, - title=title, fnum=1, ax=ax, **kw) - - # png is slightly smaller than jpg for this kind of plot - fpath = join(out_dpath, key + '.png') - # print('save fpath = {!r}'.format(fpath)) - ax.figure.savefig(fpath) + low, kw['ymax'] = inlier_ylim([t[1] for t in xydata.values()]) + + yscales = ['symlog', 'linear'] + for yscale in yscales: + fig.clf() + ax = fig.gca() + title = nice + '\n' + tag + '_' + mode + ' losses' + kwplot.multi_plot(xydata=xydata, ylabel='loss', xlabel=mode, 
+ yscale=yscale, title=title, fnum=1, ax=ax, + **kw) + if yscale == 'symlog': + if LooseVersion(mpl.__version__) >= LooseVersion('3.3'): + ax.set_yscale('symlog', linthresh=min_abs_y) + else: + ax.set_yscale('symlog', linthreshy=min_abs_y) + fname = '_'.join([tag, mode, 'multiloss', yscale]) + '.png' + fpath = join(out_dpath, fname) + ax.figure.savefig(fpath) + + # don't dump losses individually if we dump them in a group + if not GROUP_AND_INDIVIDUAL: + keys.difference_update(set(loss_keys)) + # print('keys = {!r}'.format(keys)) + + if GROUP_SPECIAL: + tag_groups = ub.group_items(keys, tag_grouper) + tag_groups.pop('unknown', None) + # Group items matching these strings + kw = {} + for tag, tag_keys in tag_groups.items(): + for groupname in special_groupers: + group_keys = [k for k in tag_keys if groupname in k.split('_')] + if len(group_keys) > 1: + # Gather data for this group + xydata = ub.odict() + for key in sorted(group_keys): + ydata = tb_data[key]['ydata'] + if HACK_NO_SMOOTH not in key.split('_'): + ydata = smooth_curve(ydata, smoothing) + xydata[key] = (tb_data[key]['xdata'], ydata) + + if ignore_outliers: + low, kw['ymax'] = inlier_ylim([t[1] for t in xydata.values()]) + + yscales = ['linear'] + for yscale in yscales: + fig.clf() + ax = fig.gca() + title = nice + '\n' + tag + '_' + mode + ' ' + groupname + kwplot.multi_plot(xydata=xydata, ylabel=groupname, xlabel=mode, + yscale=yscale, title=title, fnum=1, ax=ax, + **kw) + if yscale == 'symlog': + ax.set_yscale('symlog', linthreshy=min_abs_y) + fname = '_'.join([tag, mode, 'group-' + groupname, yscale]) + '.png' + fpath = join(out_dpath, fname) + ax.figure.savefig(fpath) + + if not GROUP_AND_INDIVIDUAL: + keys.difference_update(set(group_keys)) + + if INDIVIDUAL_PLOTS: + # print('keys = {!r}'.format(keys)) + for key in keys: + d = tb_data[key] + + ydata = d['ydata'] + ydata = smooth_curve(ydata, smoothing) + + kw = {} + if any(m.lower() in key.lower() for m in y01_measures): + kw['ymin'] = 0.0 + kw['ymax'] = 1.0 + elif any(m.lower() in key.lower() for m in y0_measures): + kw['ymin'] = min(0.0, ydata.min()) + if ignore_outliers: + low, kw['ymax'] = inlier_ylim([ydata]) + + # NOTE: this is actually pretty slow + fig.clf() + ax = fig.gca() + title = nice + '\n' + key + kwplot.multi_plot(d['xdata'], ydata, ylabel=key, xlabel=mode, + title=title, fnum=1, ax=ax, **kw) + + # png is slightly smaller than jpg for this kind of plot + fpath = join(out_dpath, key + '.png') + # print('save fpath = {!r}'.format(fpath)) + ax.figure.savefig(fpath) diff --git a/netharn/util/util_misc.py b/netharn/util/util_misc.py index 2611363830205560d073fb5af23e01708b3048fa..74bd19f707e91f31728596dcdac28e87f333225b 100644 --- a/netharn/util/util_misc.py +++ b/netharn/util/util_misc.py @@ -49,9 +49,58 @@ class FlatIndexer(ub.NiceRepr): >>> self.unravel(4) >>> self.ravel(2, 1) """ - def __init__(self, lens): + def __init__(self, lens, cums=None): self.lens = lens - self.cums = np.cumsum(lens) + if cums is None: + self.cums = np.cumsum(lens) + else: + self.cums = cums + + def concat(self, other): + """ + >>> self = FlatIndexer([1, 2, 3]) + >>> self = self.concat(self).concat(self) + >>> len(self) + """ + new_lens = self.lens + other.lens + new_cums = np.concatenate([self.cums, other.cums + self.cums[-1]], axis=0) + new = self.__class__(new_lens, new_cums) + return new + + def subslice(self, start, stop): + """ + >>> self = FlatIndexer([3, 7, 9, 4, 5] + [3] * 50) + >>> start = 4 + >>> stop = 150 + >>> self.subslice(start, stop) + >>> self.subslice(0, 10).cums + 
""" + outer1, inner1 = self.unravel(start) + outer2, inner2 = self.unravel(stop) + return self._subslice(outer1, outer2, inner1, inner2) + + def _subslice(self, outer1, outer2, inner1, inner2): + inner2 = min(self.lens[outer2], inner2) + if outer1 == outer2: + new_lens = [inner2 - inner1] + new_cums = np.array(new_lens) + else: + first = [self.lens[outer1] - inner1] + inner = self.lens[outer1 + 1:outer2] + last = [inner2] + + new_lens = first + inner + last + # Oddly, this is faster than just redoing the cumsum + # or is it now that we added a copy? + new_cums = self.cums[outer1:outer2 + 1].copy() + new_cums -= (new_cums[0] - first[0]) + new_cums[-1] = new_cums[-2] + inner2 + + if new_lens[-1] == 0: + new_lens = new_lens[:-1] + new_cums = new_cums[:-1] + new = self.__class__(new_lens, new_cums) + return new @classmethod def fromlist(cls, items): @@ -68,8 +117,15 @@ class FlatIndexer(ub.NiceRepr): Returns: Tuple[int, int]: outer and inner indices + + Example: + >>> self = FlatIndexer([1, 1]) + >>> index = 2 + >>> self.unravel(2) """ - outer = np.where(self.cums > index)[0][0] + found = np.where(self.cums > index)[0] + # Keep indexing past the end of the last bucket for slicing + outer = found[0] if len(found) else len(self.cums) - 1 base = self.cums[outer] - self.lens[outer] inner = index - base return (outer, inner) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 39edb49259b60babcb017d3e881a2ef18178e5f3..b947a61320911240c5c7914a1c7bf0ad171a8e37 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -14,7 +14,7 @@ ubelt >= 0.8.4 parse >= 1.8.4 pyflakes >= 1.6.0 astunparse >= 1.6.1 -pygtrie >= 2.2a +pygtrie >= 2.3.3 imageio > 2.6.0;python_version > '3.0' imageio < 2.8.0;python_version < '3.0' diff --git a/super_setup.py b/super_setup.py index 650e049470f03d364f04f2e850bd4742dc292f1c..aa6b163dc61b75fc9a41bf2deb8624ea1cc0bcb6 100755 --- a/super_setup.py +++ b/super_setup.py @@ -15,13 +15,41 @@ try: import click import git as gitpython import functools -except ImportError as ex: - print('ex = {!r}'.format(ex)) - print('!!!!!!!!!') - print('NEED TO INSTALL SUPER SETUP DEPENDENCIES. RUN:') - print('pip install -r requirements/super_setup.txt') - print('!!!!!!!!!') - raise +except Exception as ex: + + ALLOW_FALLBACK = True + if ALLOW_FALLBACK: + print('Attempting to install requirements/super_setup.txt') + import subprocess + import sys + try: + super_setup_dpath = dirname(__file__) + except NameError: + super_setup_dpath = '.' # For Ipython (assume in repo root) + superreq_fpath = join(super_setup_dpath, 'requirements/super_setup.txt') + args = [sys.executable, '-m', 'pip', 'install', '-r', superreq_fpath] + proc = subprocess.Popen(args) + out, err = proc.communicate() + print(out) + print(err) + + try: + import ubelt as ub + import click + import git as gitpython + import functools + except Exception: + FALLBACK_FAILED = True + else: + FALLBACK_FAILED = False + + if FALLBACK_FAILED: + print('ex = {!r}'.format(ex)) + print('!!!!!!!!!') + print('NEED TO INSTALL SUPER SETUP DEPENDENCIES. 
RUN:') + print('pip install -r requirements/super_setup.txt') + print('!!!!!!!!!') + raise class ShellException(Exception): @@ -683,57 +711,55 @@ def determine_code_dpath(): def make_netharn_registry(): code_dpath = determine_code_dpath() CommonRepo = functools.partial(Repo, code_dpath=code_dpath) - repos = [ + devel_repos = [ # The util libs - CommonRepo( - name='kwarray', branch='dev/0.5.10', remote='public', - remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwarray.git'}, - ), - CommonRepo( - name='kwimage', branch='dev/0.6.4', remote='public', - remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, - ), - # CommonRepo( # TODO - # name='kwannot', branch='dev/0.1.0', remote='public', - # remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwannot.git'}, - # ), - CommonRepo( - name='kwcoco', branch='dev/0.1.4', remote='public', - remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwcoco.git'}, - ), - CommonRepo( - name='kwplot', branch='dev/0.4.7', remote='public', - remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, - ), + { + 'name': 'kwarray', 'branch': 'dev/0.5.10', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwarray.git'}, + }, + { + 'name': 'kwimage', 'branch': 'dev/0.6.6', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, + }, + { + 'name': 'kwcoco', 'branch': 'dev/0.1.6', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwcoco.git'}, + }, + { + 'name': 'kwplot', 'branch': 'dev/0.4.8', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, + }, # Pytorch deployer / exporter - CommonRepo( - name='liberator', branch='dev/0.0.2', remote='public', - remotes={'public': 'git@gitlab.kitware.com:python/liberator.git'}, - ), - CommonRepo( - name='torch_liberator', branch='dev/0.0.4', remote='public', - remotes={'public': 'git@gitlab.kitware.com:computer-vision/torch_liberator.git'}, - ), - + { + 'name': 'liberator', 'branch': 'dev/0.0.2', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:python/liberator.git'}, + }, + { + 'name': 'torch_liberator', 'branch': 'dev/0.0.5', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/torch_liberator.git'}, + }, # For example data and CLI - CommonRepo( - name='scriptconfig', branch='dev/0.5.7', remote='public', - remotes={'public': 'git@gitlab.kitware.com:utils/scriptconfig.git'}, - ), - CommonRepo( - name='ndsampler', branch='dev/0.5.10', remote='public', - remotes={'public': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'}, - ), + { + 'name': 'scriptconfig', 'branch': 'dev/0.5.8', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:utils/scriptconfig.git'}, + }, + { + 'name': 'ndsampler', 'branch': 'dev/0.5.12', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'}, + }, # netharn - training harness - CommonRepo( - name='netharn', branch='dev/0.5.8', remote='public', - remotes={'public': 'git@gitlab.kitware.com:computer-vision/netharn.git'}, - ), + { + 'name': 'netharn', 'branch': 'dev/0.5.9', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/netharn.git'}, + }, ] + + repos = [CommonRepo(**kw) for kw in devel_repos] + registery = RepoRegistry(repos) return registery
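With the registry now expressed as plain dictionaries expanded through
``CommonRepo(**kw)``, adding another repository becomes a data-only change. A
rough sketch (the repo name, branch, and URL here are made up for
illustration):

    # Hypothetical extra entry, appended before the expansion into Repo objects
    devel_repos.append({
        'name': 'my_extension', 'branch': 'dev/0.0.1', 'remote': 'public',
        'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/my_extension.git'},
    })
    repos = [CommonRepo(**kw) for kw in devel_repos]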