diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 51383c714f7304a0a0ebbd24ae442c2a31f61cf8..50a714155367575b12ac0e8f1ae3556efc0826c0 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -303,58 +303,3 @@ test_full/cp36-cp36m-linux: # image: # python:3.6 - -# --------------- -# Python 3.5 Jobs - -build/cp35-cp35m-linux: - <<: - - *build_template - image: - python:3.5 - -test_full/cp35-cp35m-linux: - <<: - - *test_full_template - image: - python:3.5 - -#gpgsign/cp35-cp35m-linux: -# <<: -# - *gpgsign_template -# image: -# python:3.5 - -#deploy/cp35-cp35m-linux: -# <<: -# - *deploy_template -# image: -# python:3.5 - - -# --------------- -# Python 2.7 Jobs - -#build/cp27-cp27mu-linux: -# <<: -# - *build_template -# image: -# python:2.7 - -#test_full/cp27-cp27mu-linux: -# <<: -# - *test_full_template -# image: -# python:2.7 - -#gpgsign/cp27-cp27mu-linux: -# <<: -# - *gpgsign_template -# image: -# python:2.7 - -#deploy/cp27-cp27mu-linux: -# <<: -# - *deploy_template -# image: -# python:2.7 diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ad6538cf54b8f33dcc609607f0aa0cf78ab547c..627f40839f59b3dfc96ac542b6e43e91319b5663 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,41 @@ This changelog follows the specifications detailed in: [Keep a Changelog](https: This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release. -## Version 0.5.9 - Unreleased +## Version 0.5.10 - Unreleased + + +### Added +* `allow_unicode` option to `FitHarnPreferences`, which can be set to False to + disable UTF-8 characters in output formatting. + +* `IndexableWalker` in `netharn.util.util_json` (also exists in kwcoco) + +* New helper methods in `data_containers.BatchContainer` + +### Fixed +* Typo: directory `explit_checkpoints` renamed to `explicit_checkpoints`. + +* Fixed bug where epoch 0 would write a snapshot if it failed. + + +### Changed + +* Removed Python 3.5 support + +* ProgIter information will now be written to the log file pending release of ubelt 0.9.3. + +* Progress information now includes warmup LR information in the first epoch. + + +### Deprecated +* Deprecate `colored` option in `FitHarnPreferences`. Use the `NO_COLOR` environment variable to + disable ANSI coloring instead. + +* `netharn.export` has been deprecated for `torch_liberator` and `liberator`, + and will be removed in the future. + + +## Version 0.5.9 - Released 2020-08-26 ### Changed diff --git a/README.rst b/README.rst index bb1ac3b7c6eba7a89860d89f640869ff750aaa50..93d4cb38c391d2cdb9c191974079508e9385b1cb 100644 --- a/README.rst +++ b/README.rst @@ -92,8 +92,8 @@ Features (continued) * Hyperparameter tracking: The hash of your hyperparameters determines the directory data will be written to. We also allow for a "nicer" means to manage directory structures. Given a ``HyperParams`` object, we create the - symlink ``{workdir}/fit/nice/{nice}`` which points to - ``{workdir}/fit/runs/{nice}/{hashid}``. + symlink ``{workdir}/fit/name/{name}`` which points to + ``{workdir}/fit/runs/{name}/{hashid}``. * Automatic restarts: Calling ``FitHarn.run`` twice restarts training from where you left off by @@ -265,7 +265,7 @@ useful to look at. Its complexity is more than CIFAR but less than YOLO.
>>> 'workdir' : ub.ensure_app_cache_dir('netharn/demo'), >>> 'xpu' : netharn.XPU.coerce('auto'), >>> # workdir is a directory where intermediate results can be saved - >>> # "nice" symlinks /fit/name/ -> ../runs/ + >>> # "name" symlinks /fit/name/ -> ../runs/ >>> # XPU auto select a gpu if idle and VRAM>6GB else a cpu >>> # ================ >>> # Data Components @@ -303,6 +303,7 @@ useful to look at. Its complexity is more than CIFAR but less than YOLO. >>> harn = netharn.FitHarn(hyper) >>> # non-algorithmic behavior preferences (do not change learned models) >>> harn.preferences['num_keep'] = 10 + >>> harn.preferences['auto_prepare_batch'] = True >>> # start training. >>> harn.initialize(reset='delete') # delete removes an existing run >>> harn.run() # note: run calls initialize it hasn't already been called. @@ -315,7 +316,7 @@ Running this code produes the following output: RESET HARNESS BY DELETING EVERYTHING IN TRAINING DIR Symlink: /home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum -> /home/joncrall/.cache/netharn/demo/_mru ... already exists - Symlink: /home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum -> /home/joncrall/.cache/netharn/demo/fit/nice/demo + Symlink: /home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum -> /home/joncrall/.cache/netharn/demo/fit/name/demo ... already exists ... and points to the right place INFO: Initializing tensorboard (dont forget to start the tensorboard server) @@ -324,12 +325,12 @@ Running this code produes the following output: INFO: Exported model topology to /home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum/ToyNet2d_2a3f49.py INFO: Initializing model weights with: INFO: * harn.train_dpath = '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum' - INFO: * harn.nice_dpath = '/home/joncrall/.cache/netharn/demo/fit/nice/demo' + INFO: * harn.name_dpath = '/home/joncrall/.cache/netharn/demo/fit/name/demo' INFO: Snapshots will save to harn.snapshot_dpath = '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum/torch_snapshots' INFO: ARGV: /home/joncrall/.local/conda/envs/py36/bin/python /home/joncrall/.local/conda/envs/py36/bin/ipython INFO: dont forget to start: - tensorboard --logdir ~/.cache/netharn/demo/fit/nice + tensorboard --logdir ~/.cache/netharn/demo/fit/name INFO: === begin training 0 / 10 : demo === epoch lr:0.0001 │ vloss is unevaluated 0/10... rate=0 Hz, eta=?, total=0:00:00, wall=19:36 EST train loss:0.173 │ 100.00% of 64x8... rate=11762.01 Hz, eta=0:00:00, total=0:00:00, wall=19:36 EST @@ -366,9 +367,9 @@ Running this code produes the following output: INFO: training completed INFO: harn.train_dpath = '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum' - INFO: harn.nice_dpath = '/home/joncrall/.cache/netharn/demo/fit/nice/demo' + INFO: harn.name_dpath = '/home/joncrall/.cache/netharn/demo/fit/name/demo' INFO: view tensorboard results for this run via: - tensorboard --logdir ~/.cache/netharn/demo/fit/nice + tensorboard --logdir ~/.cache/netharn/demo/fit/name [DEPLOYER] Deployed zipfpath=/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum/deploy_ToyNet2d_lnejaaum_009_GAEYQT.zip INFO: wrote single-file deployment to: '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum/deploy_ToyNet2d_lnejaaum_009_GAEYQT.zip' INFO: exiting fit harness. 
@@ -381,7 +382,7 @@ then it would produce this more detailed description of what it was doing: RESET HARNESS BY DELETING EVERYTHING IN TRAINING DIR Symlink: /home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum -> /home/joncrall/.cache/netharn/demo/_mru ... already exists - Symlink: /home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum -> /home/joncrall/.cache/netharn/demo/fit/nice/demo + Symlink: /home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum -> /home/joncrall/.cache/netharn/demo/fit/name/demo ... already exists ... and points to the right place DEBUG: Initialized logging @@ -476,12 +477,12 @@ then it would produce this more detailed description of what it was doing: INFO: Initializing model weights with: DEBUG: calling harn.initializer= INFO: * harn.train_dpath = '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum' - INFO: * harn.nice_dpath = '/home/joncrall/.cache/netharn/demo/fit/nice/demo' + INFO: * harn.name_dpath = '/home/joncrall/.cache/netharn/demo/fit/name/demo' INFO: Snapshots will save to harn.snapshot_dpath = '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum/torch_snapshots' INFO: ARGV: /home/joncrall/.local/conda/envs/py36/bin/python /home/joncrall/.local/conda/envs/py36/bin/ipython --verbose INFO: dont forget to start: - tensorboard --logdir ~/.cache/netharn/demo/fit/nice + tensorboard --logdir ~/.cache/netharn/demo/fit/name INFO: === begin training 0 / 10 : demo === DEBUG: epoch lr:0.0001 │ vloss is unevaluated epoch lr:0.0001 │ vloss is unevaluated 0/10... rate=0 Hz, eta=?, total=0:00:00, wall=19:56 EST @@ -570,13 +571,24 @@ then it would produce this more detailed description of what it was doing: INFO: training completed INFO: harn.train_dpath = '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum' - INFO: harn.nice_dpath = '/home/joncrall/.cache/netharn/demo/fit/nice/demo' + INFO: harn.name_dpath = '/home/joncrall/.cache/netharn/demo/fit/name/demo' INFO: view tensorboard results for this run via: - tensorboard --logdir ~/.cache/netharn/demo/fit/nice + tensorboard --logdir ~/.cache/netharn/demo/fit/name [DEPLOYER] Deployed zipfpath=/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum/deploy_ToyNet2d_lnejaaum_000_JWPNDC.zip INFO: wrote single-file deployment to: '/home/joncrall/.cache/netharn/demo/fit/runs/demo/lnejaaum/deploy_ToyNet2d_lnejaaum_000_JWPNDC.zip' INFO: exiting fit harness. +Related Packages +================ + +pytorch-lightning (https://github.com/PyTorchLightning/pytorch-lightning) has +very similar goals to netharn. Currently, there are strengths and weaknesses to +both, but in the future I do see one consuming functionality of the other. +Currently (2020-10-21), pytorch-lightning does distributed training better, +whereas netharn's logging and hyperparameter management outshines +pytorch-lightning. + + .. 
|Pypi| image:: https://img.shields.io/pypi/v/netharn.svg :target: https://pypi.python.org/pypi/netharn diff --git a/dev/debug_memory.py b/dev/debug_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..aea3ff72a96bf85870cab521f83517078816410a --- /dev/null +++ b/dev/debug_memory.py @@ -0,0 +1,259 @@ +""" +Experiment Script Related to Pytorch Memory Leak Issue + +References: + https://github.com/pytorch/pytorch/issues/13246 + https://gist.github.com/mprostock/2850f3cd465155689052f0fa3a177a50 +""" +from torch.utils.data import Dataset, DataLoader +import numpy as np +import torch +import psutil +import ubelt as ub +import sys + + +class DataIter(Dataset): + def __init__(self, storage_mode='numpy', return_mode='tensor', total=24e7): + self.return_mode = return_mode + self.storage_mode = storage_mode + + assert self.return_mode in {'tensor', 'dict', 'tuple', 'list'} + + if storage_mode == 'numpy': + self.data = np.array([x for x in range(int(total))]) + elif storage_mode == 'python': + self.data = [x for x in range(int(total))] + elif storage_mode == 'ndsampler': + import ndsampler + assert total <= 1000 + self.data = ndsampler.CocoSampler.demo('shapes{}'.format(total)) + else: + raise KeyError(storage_mode) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if self.storage_mode == 'ndsampler': + data = self.data.load_item(idx)['im'].ravel()[0:1].astype(np.float32) + data_pt = torch.from_numpy(data) + else: + data = self.data[idx] + data = np.array([data], dtype=np.int64) + data_pt = torch.tensor(data) + + if self.return_mode == 'tensor': + item = data_pt + elif self.return_mode == 'dict': + item = { + 'data': data_pt + } + elif self.return_mode == 'tuple': + item = (data_pt,) + elif self.return_mode == 'list': + item = [data_pt] + return item + + +def getsize(*objs): + """ + sum size of object & members. + https://stackoverflow.com/questions/449560/how-do-i-determine-the-size-of-an-object-in-python + """ + import sys + from types import ModuleType, FunctionType + from gc import get_referents + # Custom objects know their class. + # Function objects seem to know way too much, including modules. + # Exclude modules as well. + blocklist = (type, ModuleType, FunctionType) + # if isinstance(obj, blocklist): + # raise TypeError('getsize() does not take argument of type: ' + str(type(obj))) + seen_ids = set() + size = 0 + objects = objs + while objects: + need_referents = [] + for obj in objects: + if not isinstance(obj, blocklist) and id(obj) not in seen_ids: + seen_ids.add(id(obj)) + size += sys.getsizeof(obj) + need_referents.append(obj) + objects = get_referents(*need_referents) + return size, len(seen_ids) + + +def byte_str(num, unit='auto', precision=2): + """ + Automatically chooses relevant unit (KB, MB, or GB) for displaying some + number of bytes. + + Args: + num (int): number of bytes + unit (str): which unit to use, can be auto, B, KB, MB, GB, TB, PB, EB, + ZB, or YB. 
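+ precision (int, default=2): number of decimal places used in the formatted result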
+ + References: + https://en.wikipedia.org/wiki/Orders_of_magnitude_(data) + + Returns: + str: string representing the number of bytes with appropriate units + + Example: + >>> num_list = [1, 100, 1024, 1048576, 1073741824, 1099511627776] + >>> result = ub.repr2(list(map(byte_str, num_list)), nl=0) + >>> print(result) + ['0.00 KB', '0.10 KB', '1.00 KB', '1.00 MB', '1.00 GB', '1.00 TB'] + """ + abs_num = abs(num) + if unit == 'auto': + if abs_num < 2.0 ** 10: + unit = 'KB' + elif abs_num < 2.0 ** 20: + unit = 'KB' + elif abs_num < 2.0 ** 30: + unit = 'MB' + elif abs_num < 2.0 ** 40: + unit = 'GB' + elif abs_num < 2.0 ** 50: + unit = 'TB' + elif abs_num < 2.0 ** 60: + unit = 'PB' + elif abs_num < 2.0 ** 70: + unit = 'EB' + elif abs_num < 2.0 ** 80: + unit = 'ZB' + else: + unit = 'YB' + if unit.lower().startswith('b'): + num_unit = num + elif unit.lower().startswith('k'): + num_unit = num / (2.0 ** 10) + elif unit.lower().startswith('m'): + num_unit = num / (2.0 ** 20) + elif unit.lower().startswith('g'): + num_unit = num / (2.0 ** 30) + elif unit.lower().startswith('t'): + num_unit = num / (2.0 ** 40) + elif unit.lower().startswith('p'): + num_unit = num / (2.0 ** 50) + elif unit.lower().startswith('e'): + num_unit = num / (2.0 ** 60) + elif unit.lower().startswith('z'): + num_unit = num / (2.0 ** 70) + elif unit.lower().startswith('y'): + num_unit = num / (2.0 ** 80) + else: + raise ValueError('unknown num={!r} unit={!r}'.format(num, unit)) + return ub.repr2(num_unit, precision=precision) + ' ' + unit + + +def main(storage_mode='numpy', return_mode='tensor', total=24e5, shuffle=True): + """ + Args: + storage_mode : how the dataset is stored in backend datasets + + return_mode : how each data item is returned + + total : size of backend storage + + """ + mem = psutil.virtual_memory() + start_mem = mem.used + mem_str = byte_str(start_mem) + print('Starting used system memory = {!r}'.format(mem_str)) + + train_data = DataIter( + storage_mode=storage_mode, + return_mode=return_mode, + total=total) + # self = train_data + + if storage_mode == 'numpy': + total_storage_bytes = train_data.data.dtype.itemsize * train_data.data.size + else: + total_storage_bytes = sys.getsizeof(train_data.data) + # total_storage_bytes = getsize(self.data) + print('total_storage_size = {!r}'.format(byte_str(total_storage_bytes))) + + mem = psutil.virtual_memory() + mem_str = byte_str(mem.used - start_mem) + print('After init DataIter memory = {!r}'.format(mem_str)) + + print('shuffle = {!r}'.format(shuffle)) + + num_workers = 2 + train_loader = DataLoader(train_data, batch_size=300, + shuffle=shuffle, + drop_last=True, + pin_memory=False, + num_workers=num_workers) + + used_nbytes = psutil.virtual_memory().used - start_mem + print('After init DataLoader memory = {!r}'.format(byte_str(used_nbytes))) + + if True: + # Estimate peak usage + import gc + all_obj_nbytes, num_objects = getsize(*gc.get_objects()) + python_ptr_size = int((np.log2(sys.maxsize) + 1) / 8) + assert python_ptr_size == 8, 'should be 8 bytes on 64bit python' + all_ptr_nbytes = (num_objects * python_ptr_size) + + prog_nbytes_estimated_1 = all_ptr_nbytes + all_obj_nbytes + prog_nbytes_measured_2 = psutil.virtual_memory().used - start_mem + print('prog_nbytes_estimated_1 = {!r}'.format(byte_str(prog_nbytes_estimated_1))) + print('prog_nbytes_measured_2 = {!r}'.format(byte_str(prog_nbytes_measured_2))) + + peak_bytes_est1 = prog_nbytes_estimated_1 * (num_workers + 1) + peak_bytes_est2 = prog_nbytes_measured_2 * (num_workers + 1) + print('peak_bytes_est1 =
{!r}'.format(byte_str(peak_bytes_est1))) + print('peak_bytes_est2 = {!r}'.format(byte_str(peak_bytes_est2))) + + max_bytes = -float('inf') + prog = ub.ProgIter(train_loader) + for item in prog: + used_bytes = psutil.virtual_memory().used - start_mem + max_bytes = max(max_bytes, used_bytes) + prog.set_extra(' Mem=' + byte_str(used_bytes)) + + used_bytes = psutil.virtual_memory().used - start_mem + print('measured final usage: {}'.format(byte_str(used_bytes))) + print('measured peak usage: {}'.format(byte_str(max_bytes))) + + +if __name__ == '__main__': + """ + CommandLine: + python debug_memory.py numpy tensor --total=24e5 --shuffle=True + + cd ~/code/netharn/dev + + python debug_memory.py --storage_mode=numpy --total=24e5 --shuffle=True + python debug_memory.py --storage_mode=numpy --total=24e5 --shuffle=False + python debug_memory.py --storage_mode=python --total=24e5 --shuffle=True + python debug_memory.py --storage_mode=python --total=24e5 --shuffle=False + + python debug_memory.py --storage_mode=ndsampler --total=1000 --shuffle=True + + python debug_memory.py numpy dict 24e5 + python debug_memory.py python list 24e7 + + Conclusions: + + * It seems like it is ok if the return type is a dictionary + the problem seems to be localized to the storage type. + """ + import fire + fire.Fire(main) + +""" + +@VitalyFedyunin Let me see if I understand correctly, when you access an item +in a list you create a new reference to it, which will force its refcount to be +incremented (i.e. be written to). + +pages are typically 4096 bytes. + +""" diff --git a/dev/list_deployed.py b/dev/list_deployed.py new file mode 100644 index 0000000000000000000000000000000000000000..49e20eda1f37466d986a3418f6db4beb30b1daf6 --- /dev/null +++ b/dev/list_deployed.py @@ -0,0 +1,60 @@ +""" +Simple script that prints the deployed models in a given netharn work directory +""" +import scriptconfig as scfg +import ubelt as ub +import glob +from os.path import join, exists + + +class ListDeployedConfig(scfg.Config): + """ + Given a netharn work directory list all deployed models + """ + default = { + 'workdir': scfg.Value(None, help='work directory'), + 'name': scfg.Value(None, help='"nice" name of the run'), + } + + +def main(cmdline=True, **kw): + config = ListDeployedConfig(cmdline=cmdline, default=kw) + print('config = {}'.format(ub.repr2(dict(config), nl=1))) + + runs_dpath = join(config['workdir'], 'fit/runs') + if not exists(runs_dpath): + print('Workdir does not seem to contain a runs dpath') + print('Checking for alternates? TODO') + raise NotImplementedError + + workdirs = [config['workdir']] + for workdir in workdirs: + run_name = config['name'] + if run_name is None: + named_run_dpath = join(runs_dpath, '*') + dpath_exists = exists(named_run_dpath) + print('dpath_exists = {!r}'.format(dpath_exists)) + else: + named_run_dpath = join(runs_dpath, run_name) + dpath_exists = exists(named_run_dpath) + print('dpath_exists = {!r}'.format(dpath_exists)) + + # TODO: do we want to remove deploy.zip symlinks here? 
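+ # deployed models live one level below the named run dpath, i.e. fit/runs/{name}/{hashid}/*.zip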
+ deployed_fpaths = glob.glob(join(named_run_dpath, '*/*.zip')) + + SHRINK = 1 + if SHRINK: + # Make output text smaller, and more likely to work cross-system + deployed_fpaths = [ + ub.shrinkuser(fpath, home='$HOME') + for fpath in deployed_fpaths] + + print('deployed_fpaths = {}'.format(ub.repr2(deployed_fpaths, nl=1))) + + +if __name__ == '__main__': + """ + CommandLine: + python ~/code/netharn/dev/list_deployed.py --workdir $HOME/work/netharn + """ + main() diff --git a/dev/manage_snapshots.py b/dev/manage_snapshots.py index ddc08d99aba16e82d122502a2bfccf3d7cc27032..83b3c03fa87c34da6460a992b8d3b6a61d33e93e 100755 --- a/dev/manage_snapshots.py +++ b/dev/manage_snapshots.py @@ -13,6 +13,7 @@ import numpy as np import os import parse import ubelt as ub +import copy def byte_str(num, unit='auto', precision=2): @@ -129,13 +130,100 @@ def get_file_info(fpath): return info +def _demodata_workdir(): + """ + Make a work directory with various types of sessions + """ + workdir = ub.ensure_app_cache_dir('netharn/tests/sessions') + def _demodata_toy_session(workdir, name='demo_session', lr=1e-4): + """ + workdir = ub.ensure_app_cache_dir('netharn/tests/sessions') + workdir + """ + # This will train a toy model with toy data using netharn + import netharn as nh + hyper = nh.HyperParams(**{ + 'workdir' : ub.ensure_app_cache_dir('netharn/tests/sessions'), + 'name' : name, + 'xpu' : nh.XPU.coerce('cpu'), + 'datasets' : {'train': nh.data.ToyData2d(size=3, rng=0), 'vali': nh.data.ToyData2d(size=3, rng=0)}, + 'loaders' : {'batch_size': 64}, + 'model' : (nh.models.ToyNet2d, {}), + 'optimizer' : (nh.optimizers.SGD, {'lr': lr}), + 'criterion' : (nh.criterions.FocalLoss, {}), + 'initializer' : (nh.initializers.KaimingNormal, {}), + 'monitor' : (nh.Monitor, {'max_epoch': 1}), + }) + harn = nh.FitHarn(hyper) + harn.preferences['use_tensorboard'] = False + harn.preferences['timeout'] = 1 + harn.run() # TODO: make this run faster if we don't need to rerun + + _demodata_toy_session(workdir, name='demo_session1', lr=1e-3) + _demodata_toy_session(workdir, name='demo_session2', lr=1e-3) + _demodata_toy_session(workdir, name='demo_session3', lr=1e-3) + _demodata_toy_session(workdir, name='demo_session2', lr=1e-4) + _demodata_toy_session(workdir, name='demo_session3', lr=1e-4) + _demodata_toy_session(workdir, name='demo_session3', lr=1e-5) + return workdir + + +def collect_sessions(workdir): + """ + Netharn writes all training runs into a work directory under + {workdir}/fit/runs/{name}/{hashid}/ and makes symlinks in + {workdir}/fit/name/{name}. This collects all sessions within a workdir that + match the filter criteria.
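+ + Args: + workdir (str): path to a netharn work directory + + Returns: + List[Session]: a Session object for each training directory found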
+ + workdir = _demodata_workdir() + all_sessions = collect_sessions(workdir) + + """ + run_dpath = join(workdir, 'fit', 'runs') + training_dpaths = list(glob.glob(join(run_dpath, '*/*'))) + + all_sessions = [] + for dpath in ub.ProgIter(training_dpaths, desc='collect sessions', freq=1): + session = Session(dpath) + all_sessions.append(session) + return all_sessions + + +class Session(ub.NiceRepr): + """ + NEW: object to maintain info / manipulate a specific training directory + + TODO: + - [ ] Lazy properties + - [ ] Better convinience methods + - [ ] Log parsing + """ + def __init__(session, dpath): + session.dpath = dpath + info, details = session_info(session.dpath) + session.info = info + session.details = details + + def __nice__(session): + return repr(session.info) + + def session_info(dpath): """ Stats about a training session """ info = {} snap_dpath = join(dpath, 'torch_snapshots') - snapshots = os.listdir(snap_dpath) if exists(snap_dpath) else [] + check_dpath = join(dpath, 'checkpoints') + if exists(check_dpath): + snapshots = os.listdir(check_dpath) if exists(check_dpath) else [] + snapshots = [join(check_dpath, fname) for fname in snapshots] + elif exists(snap_dpath): + # Old snapshot directory name + snapshots = os.listdir(snap_dpath) if exists(snap_dpath) else [] + snapshots = [join(snap_dpath, fname) for fname in snapshots] + else: + snapshots = [] dpath = realpath(dpath) if True: @@ -143,23 +231,40 @@ def session_info(dpath): name = basename(dirname(dpath)) info['name'] = name fitdir = dirname(dirname(dirname(dpath))) + target = None name_dpath = join(fitdir, 'name', name) - try: - target = realpath(ub.util_links._readlink(name_dpath)) - except Exception: - target = None + if exists(name_dpath): + try: + target = realpath(ub.util_links._readlink(name_dpath)) + except Exception: + target = None + else: + nice_dpath = join(fitdir, 'nice', name) + if exists(nice_dpath): + try: + target = realpath(ub.util_links._readlink(nice_dpath)) + except Exception: + target = None info['linked'] = (target == dpath) + best_snapshot_fpath = join(dpath, 'best_snapshot.pt') + details = {} + details['best_snapshot'] = best_snapshot_fpath if exists(best_snapshot_fpath) else None + details['deployed'] = [p for p in glob.glob(join(dpath, '*.zip')) if not ub.util_links.islink(p)] + details['snapshots'] = snapshots + info['dpath'] = dpath + info['has_deploy'] = bool(details['deployed']) + info['has_best'] = bool(details['best_snapshot']) info['num_snapshots'] = len(snapshots) info['size'] = float(ub.cmd('du -s ' + dpath)['out'].split('\t')[0]) if len(snapshots) > 0: contents = [join(dpath, c) for c in os.listdir(dpath)] - timestamps = [get_file_info(c)['last_modified'] for c in contents] + timestamps = [get_file_info(c)['last_modified'] for c in contents if exists(c)] unixtime = max(timestamps) dt = datetime.datetime.fromtimestamp(unixtime) info['last_modified'] = dt - return info + return info, details def _devcheck_remove_dead_runs(workdir, dry=True, dead_num_snap_thresh=10, @@ -177,37 +282,35 @@ def _devcheck_remove_dead_runs(workdir, dry=True, dead_num_snap_thresh=10, workdir = '.' 
import xdev globals().update(xdev.get_func_kwargs(_devcheck_remove_dead_runs)) + + workdir = _demodata_workdir() + _devcheck_remove_dead_runs(workdir) """ - import ubelt as ub - import copy print('Checking for dead / dangling sessions in your runs dir') - # Find if any run directory is empty run_dpath = join(workdir, 'fit', 'runs') - training_dpaths = list(glob.glob(join(run_dpath, '*/*'))) - - all_sessions = [] - for dpath in training_dpaths: - session = session_info(dpath) - all_sessions.append(session) + all_sessions = collect_sessions(workdir) + # Find if any run directory is empty now = datetime.datetime.now() long_time_ago = now - datetime.timedelta(days=safe_num_days) for session in all_sessions: - if session['num_snapshots'] == 0: - session['decision'] = 'bad' - elif session['num_snapshots'] < dead_num_snap_thresh: - dt = session['last_modified'] + info = session.info + if not (info['has_deploy'] or info['num_snapshots'] or info['has_best']): + info['decision'] = 'bad' + elif info['num_snapshots'] < dead_num_snap_thresh: + dt = info.get('last_modified', now) if dt < long_time_ago: - session['decision'] = 'iffy' + info['decision'] = 'iffy' else: - session['decision'] = 'good' + info['decision'] = 'good' else: - session['decision'] = 'good' + info['decision'] = 'good' - nice_groups = ub.group_items(all_sessions, lambda x: x['name']) + all_info = [s.info for s in all_sessions] + nice_groups = ub.group_items(all_info, lambda x: x['name']) for name, group in nice_groups.items(): print(' --- {} --- '.format(name)) group = sorted(group, key=lambda x: x['size']) @@ -221,9 +324,14 @@ def _devcheck_remove_dead_runs(workdir, dry=True, dead_num_snap_thresh=10, # Partion your "name" sessions into broken and live symlinks. # For each live link remember what the real path is. 
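+ # check both the current 'name' dpath and the legacy 'nice' dpath for broken symlinks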
broken_links = [] + nice_dpath = join(workdir, 'fit', 'nice') name_dpath = join(workdir, 'fit', 'name') - for dname in os.listdir(name_dpath): - dpath = join(name_dpath, dname) + dpaths = [] + if exists(name_dpath): + dpaths += [join(name_dpath, d) for d in os.listdir(name_dpath)] + if exists(nice_dpath): + dpaths += [join(nice_dpath, d) for d in os.listdir(nice_dpath)] + for dpath in dpaths: if is_symlink_broken(dpath): broken_links.append(dpath) @@ -233,7 +341,7 @@ def _devcheck_remove_dead_runs(workdir, dry=True, dead_num_snap_thresh=10, if len(os.listdir(dpath)) == 0: empty_dpaths.append(dpath) - decision_groups = ub.group_items(all_sessions, lambda x: x['decision']) + decision_groups = ub.group_items(all_info, lambda x: x['decision']) print('Empty dpaths: {:>4}'.format(len(empty_dpaths))) print('Broken links: {:>4}'.format(len(broken_links))) @@ -255,75 +363,47 @@ def _devcheck_remove_dead_runs(workdir, dry=True, dead_num_snap_thresh=10, for p in empty_dpaths: ub.delete(p) for p in broken_links: - os.unlink(info['dpath']) - + os.unlink(p) -class Session(ub.NiceRepr): - """ - UNFINISHED: - NEW: object to maintain info / manipulate a specific training directory - """ - def __init__(session, dpath): - session.dpath = dpath - session.info = session_info(session.dpath) - def __nice__(session): - return repr(session.info) +def _devcheck_manage_monitor(workdir, dry=True): + all_sessions = collect_sessions(workdir) -def _devcheck_manage_monitor(workdir, dry=True): # Get all the images in the monitor directories # (this is a convention and not something netharn does by default) - run_dpath = join(workdir, 'fit', 'runs') - training_dpaths = list(glob.glob(join(run_dpath, '*/*'))) - - all_sessions = [] - for dpath in training_dpaths: - session = Session(dpath) - all_sessions.append(session) - # UNFINISHED all_files = [] - factor = 100 + # factor = 100 + max_keep = 300 def _choose_action(file_infos): import kwarray file_infos = kwarray.shuffle(file_infos, rng=0) - n_keep = (len(file_infos) // factor) + 1 + n_keep = max_keep + # n_keep = (len(file_infos) // factor) + 1 + # n_keep = min(max_keep, n_keep) for info in file_infos[:n_keep]: info['action'] = 'keep' for info in file_infos[n_keep:]: info['action'] = 'delete' - for session in all_sessions: - dpath = join(session.dpath, 'monitor', 'train', 'batch') - fpaths = list(glob.glob(join(dpath, '*.jpg'))) - file_infos = [{'size': os.stat(p).st_size, 'fpath': p} - for p in fpaths] - _choose_action(file_infos) - all_files.extend(file_infos) - - dpath = join(session.dpath, 'monitor', 'vali', 'batch') - fpaths = list(glob.glob(join(dpath, '*.jpg'))) - file_infos = [{'size': os.stat(p).st_size, 'fpath': p} - for p in fpaths] - _choose_action(file_infos) - all_files.extend(file_infos) - - dpath = join(session.dpath, 'monitor', 'train') - fpaths = list(glob.glob(join(dpath, '*.jpg'))) - file_infos = [{'size': os.stat(p).st_size, 'fpath': p} - for p in fpaths] - _choose_action(file_infos) - all_files.extend(file_infos) - - dpath = join(session.dpath, 'monitor', 'vali') - fpaths = list(glob.glob(join(dpath, '*.jpg'))) - file_infos = [{'size': os.stat(p).st_size, 'fpath': p} - for p in fpaths] - _choose_action(file_infos) - all_files.extend(file_infos) + for session in ub.ProgIter(all_sessions, desc='checking monitor files'): + dpaths = [ + join(session.dpath, 'monitor', 'train', 'batch'), + join(session.dpath, 'monitor', 'vali', 'batch'), + join(session.dpath, 'monitor', 'train'), + join(session.dpath, 'monitor', 'vali'), + ] + exts = ['*.jpg', '*.png'] + 
for dpath in dpaths: + for ext in exts: + fpaths = list(glob.glob(join(dpath, ext))) + file_infos = [{'size': os.stat(p).st_size, 'fpath': p} + for p in fpaths] + _choose_action(file_infos) + all_files.extend(file_infos) grouped_actions = ub.group_items(all_files, lambda x: x['action']) @@ -336,7 +416,7 @@ def _devcheck_manage_monitor(workdir, dry=True): else: delete = grouped_actions.get('delete', []) delete_fpaths = [item['fpath'] for item in delete] - for p in delete_fpaths: + for p in ub.ProgIter(delete_fpaths, desc='deleting'): ub.delete(p) @@ -368,19 +448,21 @@ def _devcheck_manage_snapshots(workdir, recent=5, factor=10, dry=True): USE_RANGE_HUERISTIC = True - run_dpath = join(workdir, 'fit', 'runs') - snapshot_dpaths = list(glob.glob(join(run_dpath, '**/torch_snapshots'), recursive=True)) - print('checking {} snapshot paths'.format(len(snapshot_dpaths))) + all_sessions = collect_sessions(workdir) + print('Checking sessions = {}'.format(ub.repr2(all_sessions, nl=1))) all_keep = [] all_remove = [] - for snapshot_dpath in snapshot_dpaths: - snapshots = sorted(glob.glob(join(snapshot_dpath, '_epoch_*.pt'))) - epoch_to_snap = { - int(parse.parse('{}_epoch_{num:d}.pt', path).named['num']): path - for path in snapshots - } + for session in all_sessions: + snapshots = session.details['snapshots'] + epoch_to_snap = {} + extra_types = {'prefix': parse.with_pattern('.*')(ub.identity)} + for path in snapshots: + parsed = parse.parse('{:prefix}_epoch_{num:d}.pt', path, extra_types) + if parsed: + epoch = int(parsed.named['num']) + epoch_to_snap[epoch] = path existing_epochs = sorted(epoch_to_snap.keys()) # print('existing_epochs = {}'.format(ub.repr2(existing_epochs))) toremove = [] @@ -411,7 +493,7 @@ def _devcheck_manage_snapshots(workdir, recent=5, factor=10, dry=True): print('keep = {!r}'.format(sorted(keep))) print('kill = {!r}'.format(sorted(kill))) - print('Keep {}/{} from {}'.format(len(keep), len(existing_epochs), snapshot_dpath)) + print('Keep {}/{} from {}'.format(len(keep), len(existing_epochs), session.info['dpath'])) all_keep += [tokeep] all_remove += [toremove] @@ -426,14 +508,14 @@ def _devcheck_manage_snapshots(workdir, recent=5, factor=10, dry=True): for path in ub.flatten(all_remove): total += os.path.getsize(path) - total_mb = total / 2 ** 20 if dry: - print('Cleanup would delete {} snapshots and free {!r} MB'.format(len(all_remove), total_mb)) + print('Cleanup would delete {} snapshots and free {}'.format(len(all_remove), byte_str(total))) print('Use -f to confirm and force cleanup') else: - print('About to free {!r} MB'.format(total_mb)) - for path in ub.flatten(all_remove): - ub.delete(path, verbose=True) + print('About to free {}'.format(byte_str(total))) + fpaths = list(ub.flatten(all_remove)) + for path in ub.ProgIter(fpaths, desc='deleting'): + ub.delete(path) def main(): @@ -481,6 +563,9 @@ if __name__ == '__main__': python ~/code/netharn/dev/manage_snapshots.py --mode=snapshots --workdir=~/work/voc_yolo2/ --recent 2 --factor 40 python ~/code/netharn/dev/manage_snapshots.py --mode=runs --workdir=~/work/voc_yolo2/ python ~/code/netharn/dev/manage_snapshots.py --mode=monitor --workdir=~/work/voc_yolo2/ + python ~/code/netharn/dev/manage_snapshots.py --mode=monitor --workdir=. -f + python ~/code/netharn/dev/manage_snapshots.py --mode=runs --workdir=. + python ~/code/netharn/dev/manage_snapshots.py --mode=snapshots --workdir=. 
--recent 2 --factor 40 -f Notes: # Remove random files diff --git a/docs/bugs.md b/docs/bugs.md index 0e6f1cda57bd1731ce10771f82f8c6f8d7f80c36..cfb77bef6b2f0fe8c6a3d8ff21e7642c1316ced7 100644 --- a/docs/bugs.md +++ b/docs/bugs.md @@ -4,3 +4,21 @@ * The per-batch iteration metrics seem to jump on the x-axis in the tensorboard logs. Not sure why this is. Perhaps there is a scheduler bug? + + +* If PyQt5 is installed and there is a problem with the matplotlib qt backend + then you may just get an error that crashes the system: + + ``` + QObject::moveToThread: Current thread (0x5636e99e0690) is not the object's thread (0x5636ea1e26b0). + Cannot move to target thread (0x5636e99e0690) + + qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "$HOME/.local/conda/envs/py38/lib/python3.8/site-packages/cv2/qt/plugins" even though it was found. + This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem. + + Available platform plugins are: xcb, eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl. + +``` + +The workaround is to uninstall PyQt5, but that's not great. Need to detect that +this will happen before it does so we can warn and avoid it. diff --git a/netharn/__init__.py b/netharn/__init__.py index 77ae9ac9fc5bf22fc7a40f0729ac2083c1f8ddcd..4add55830c2f20ba45bd337d1de7ab7833e2cccc 100644 --- a/netharn/__init__.py +++ b/netharn/__init__.py @@ -4,7 +4,7 @@ mkinit netharn --noattrs --dry mkinit netharn --noattrs """ -__version__ = '0.5.9' +__version__ = '0.5.10' try: # PIL 7.0.0 removed PIL_VERSION, which breaks torchvision, monkey patch it @@ -68,7 +68,6 @@ from netharn import criterions from netharn import data from netharn import device from netharn import exceptions -from netharn import export from netharn import fit_harn from netharn import hyperparams from netharn import initializers @@ -90,7 +89,7 @@ __all__ = ['Criterion', 'Dynamics', 'FitHarn', 'HiddenFields', 'HiddenShapes', 'Optimizer', 'OutputShape', 'OutputShapeFor', 'ReceptiveField', 'ReceptiveFieldFor', 'Scheduler', 'XPU', 'analytic_for', 'api', 'configure_hacks', 'configure_workdir', 'criterions', 'data', - 'device', 'exceptions', 'export', 'fit_harn', 'hyperparams', + 'device', 'exceptions', 'fit_harn', 'hyperparams', 'initializers', 'layers', 'metrics', 'mixins', 'models', 'monitor', 'optimizers', 'output_shape_for', 'prefit', 'receptive_field_for', 'schedulers', 'util'] diff --git a/netharn/analytic/receptive_field_for.py b/netharn/analytic/receptive_field_for.py index 0d10905732d49252b0b58cb671e9c8502c423290..a768ee1eadcabb65b94c3d0152f707a40bde99fa 100644 --- a/netharn/analytic/receptive_field_for.py +++ b/netharn/analytic/receptive_field_for.py @@ -1001,13 +1001,42 @@ def effective_receptive_feild(module, inputs, output_key=None, sigma=0, >>> kwplot.imshow(emperical_field['impact'], doclf=True) Ignore: + >>> def forward(self, x): + >>> # See note [TorchScript super()] + >>> x = self.conv1(x) + >>> x = self.bn1(x) + >>> x = self.relu(x) + >>> x = self.maxpool(x) + >>> # + >>> x = self.layer1(x) + >>> x = self.layer2(x) + >>> x = self.layer3(x) + >>> x = self.layer4(x) + >>> # + >>> #x = self.avgpool(x) + >>> #x = torch.flatten(x, 1) + >>> #x = self.fc(x) + >>> return x >>> xpu = nh.XPU.coerce('auto') - >>> module = xpu.move(torchvision.models.resnet50()) - >>> inputs = xpu.move(torch.rand(8, 3, 224, 224)) - >>> emperical_field = effective_receptive_feild(module, 
inputs) + >>> module1 = torchvision.models.resnet50() + >>> ub.inject_method(module1, forward) + >>> module1 = xpu.move(module1) + >>> module2 = torchvision.models.resnet50(pretrained=True) + >>> module2 = xpu.move(module2) + >>> ub.inject_method(module2, forward) + >>> import kwimage + >>> img = kwimage.grab_test_image(dsize=(224, 224)) + >>> inputs = torch.from_numpy(img.transpose(2, 0, 1)[None, :] / 255.).float() + >>> inputs = xpu.move(inputs) + >>> #inputs = xpu.move(torch.rand(8, 3, 224, 224)) + >>> ignore_norms = 1 + >>> emperical_field1 = effective_receptive_feild(module1, inputs, ignore_norms=ignore_norms) + >>> emperical_field2 = effective_receptive_feild(module2, inputs, ignore_norms=ignore_norms) >>> import kwplot >>> kwplot.autompl() - >>> kwplot.imshow(emperical_field['impact'], doclf=True) + >>> kwplot.imshow(inputs[0].data.cpu().numpy(), fnum=1, pnum=(1, 3, 1), title='input', doclf=1) + >>> kwplot.imshow(emperical_field1['impact'], fnum=1, pnum=(1, 3, 2), title='pretrained=False') + >>> kwplot.imshow(emperical_field2['impact'], doclf=0, fnum=1, pnum=(1, 3, 3), title='pretrained=True') """ import netharn as nh @@ -1043,10 +1072,10 @@ def effective_receptive_feild(module, inputs, output_key=None, sigma=0, outputs = module(inputs) # Note: grab a single (likely FCN) output channel - if callable(output_key): - output_y = output_key(outputs) - elif output_key is None: + if output_key is None: output_y = outputs + elif callable(output_key): + output_y = output_key(outputs) else: output_y = outputs[output_key] # elif isinstance(output_key, (six.string_types, int)): diff --git a/netharn/api.py b/netharn/api.py index cb75bc607f67f53633a75e2fa85807dcb70b4185..72f7ccf2515aa87049b8d2e7fdf7645522a2dba7 100644 --- a/netharn/api.py +++ b/netharn/api.py @@ -43,6 +43,211 @@ class Datasets(object): return torch_datasets +class DatasetInfo(object): + """ + experimental, attempts to do more heavy lifting + """ + @staticmethod + def coerce(config={}, **kw): + """ + Accepts 'datasets', 'train_dataset', 'vali_dataset', and 'test_dataset'. + + Args: + config (dict | str): coercable configuration dictionary. + """ + config = _update_defaults(config, kw) + dataset_info = _coerce_datasets(config) + return dataset_info + + +def _coerce_datasets(config): + import netharn as nh + import ndsampler + import numpy as np + from torchvision import transforms + coco_datasets = nh.api.Datasets.coerce(config) + print('coco_datasets = {}'.format(ub.repr2(coco_datasets, nl=1))) + for tag, dset in coco_datasets.items(): + dset._build_hashid(hash_pixels=False) + + workdir = ub.ensuredir(ub.expandpath(config['workdir'])) + samplers = { + tag: ndsampler.CocoSampler(dset, workdir=workdir, backend=config['sampler_backend']) + for tag, dset in coco_datasets.items() + } + + for tag, sampler in ub.ProgIter(list(samplers.items()), desc='prepare frames'): + sampler.frames.prepare(workers=config['workers']) + + # TODO: basic ndsampler torch dataset, likely has to support the transforms + # API, bleh. 
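+ # resize / center-crop each image to the configured input_dims and rescale tensor values to the 0-255 range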
+ + transform = transforms.Compose([ + transforms.Resize(config['input_dims']), + transforms.CenterCrop(config['input_dims']), + transforms.ToTensor(), + transforms.Lambda(lambda x: x.mul(255)) + ]) + + torch_datasets = { + key: SamplerDataset( + sampler, transform=transform, + # input_dims=config['input_dims'], + # augmenter=config['augmenter'] if key == 'train' else None, + ) + for key, sampler in samplers.items() + } + # self = torch_dset = torch_datasets['train'] + + if config['normalize_inputs']: + # Get stats on the dataset (todo: turn off augmentation for this) + import kwarray + _dset = torch_datasets['train'] + stats_idxs = kwarray.shuffle(np.arange(len(_dset)), rng=0)[0:min(1000, len(_dset))] + stats_subset = torch.utils.data.Subset(_dset, stats_idxs) + + cacher = ub.Cacher('dset_mean', cfgstr=_dset.input_id + 'v3') + input_stats = cacher.tryload() + + from netharn.data.channel_spec import ChannelSpec + channels = ChannelSpec.coerce(config['channels']) + + if input_stats is None: + # Use parallel workers to load data faster + from netharn.data.data_containers import container_collate + from functools import partial + collate_fn = partial(container_collate, num_devices=1) + + loader = torch.utils.data.DataLoader( + stats_subset, + collate_fn=collate_fn, + num_workers=config['workers'], + shuffle=True, + batch_size=config['batch_size']) + + # Track moving average of each fused channel stream + channel_stats = {key: nh.util.RunningStats() + for key in channels.keys()} + assert len(channel_stats) == 1, ( + 'only support one fused stream for now') + for batch in ub.ProgIter(loader, desc='estimate mean/std'): + if isinstance(batch, (tuple, list)): + inputs = {'rgb': batch[0]} # make assumption + else: + inputs = batch['inputs'] + + for key, val in inputs.items(): + try: + for part in val.numpy(): + channel_stats[key].update(part) + except ValueError: # final batch broadcast error + pass + + perchan_input_stats = {} + for key, running in channel_stats.items(): + running = ub.peek(channel_stats.values()) + perchan_stats = running.simple(axis=(1, 2)) + perchan_input_stats[key] = { + 'std': perchan_stats['std'].round(3), + 'mean': perchan_stats['mean'].round(3), + } + + input_stats = ub.peek(perchan_input_stats.values()) + cacher.save(input_stats) + else: + input_stats = {} + + torch_loaders = { + tag: dset.make_loader( + batch_size=config['batch_size'], + num_batches=config['num_batches'], + num_workers=config['workers'], + shuffle=(tag == 'train'), + balance=(config['balance'] if tag == 'train' else None), + pin_memory=True) + for tag, dset in torch_datasets.items() + } + + dataset_info = { + 'torch_datasets': torch_datasets, + 'torch_loaders': torch_loaders, + 'input_stats': input_stats + } + return dataset_info + + +class SamplerDataset(torch.utils.data.Dataset): + def __init__(self, sampler, transform=None, return_style='torchvision'): + self.sampler = sampler + self.transform = transform + self.return_style = return_style + self.input_id = self.sampler.hashid + + def __len__(self): + return len(self.sampler) + + def __getitem__(self, index): + item = self.sampler.load_item(index) + numpy_im = item['im'] + + if self.transform: + from PIL import Image + pil_im = Image.fromarray(numpy_im) + torch_chw = self.transform(pil_im) + else: + torch_chw = torch.from_numpy(numpy_im).permute(2, 0, 1).float() + # raise NotImplementedError + + if self.return_style == 'torchvision': + cid = item['tr']['category_id'] + cidx = self.sampler.classes.id_to_idx[cid] + return torch_chw, cidx + else: + raise
NotImplementedError + + def make_loader(self, batch_size=16, num_batches='auto', num_workers=0, + shuffle=False, pin_memory=False, drop_last=False, + balance=None): + + import kwarray + if len(self) == 0: + raise Exception('must have some data') + + def worker_init_fn(worker_id): + import numpy as np + for i in range(worker_id + 1): + seed = np.random.randint(0, int(2 ** 32) - 1) + seed = seed + worker_id + kwarray.seed_global(seed) + # if self.augmenter: + # rng = kwarray.ensure_rng(None) + # self.augmenter.seed_(rng) + + loaderkw = { + 'num_workers': num_workers, + 'pin_memory': pin_memory, + 'worker_init_fn': worker_init_fn, + } + if balance is None: + loaderkw['shuffle'] = shuffle + loaderkw['batch_size'] = batch_size + loaderkw['drop_last'] = drop_last + elif balance == 'classes': + from netharn.data.batch_samplers import BalancedBatchSampler + index_to_cid = [ + cid for cid in self.sampler.regions.targets['category_id'] + ] + batch_sampler = BalancedBatchSampler( + index_to_cid, batch_size=batch_size, + shuffle=shuffle, num_batches=num_batches) + loaderkw['batch_sampler'] = batch_sampler + else: + raise KeyError(balance) + + loader = torch.utils.data.DataLoader(self, **loaderkw) + return loader + + class Initializer(object): """ Base class for all netharn initializers @@ -102,7 +307,7 @@ class Initializer(object): >>> print(ub.repr2(nh.Initializer.coerce(config))) ( , - {... 'fpath': '/fit/nice/untitled', 'leftover': None, 'mangle': True}, + {... 'fpath': '/fit/nice/untitled', 'leftover': None, 'mangle': False}, ) >>> print(ub.repr2(nh.Initializer.coerce({'init': 'kaiming_normal'}))) ( @@ -150,7 +355,7 @@ class Initializer(object): initializer_ = (nh.initializers.Pretrained, { 'fpath': ub.expandpath(config['pretrained_fpath']), 'leftover': kw.get('leftover', None), - 'mangle': kw.get('mangle', True), + 'mangle': kw.get('mangle', False), 'association': kw.get('association', None), }) elif config['init'] == 'cls': @@ -255,25 +460,21 @@ class Optimizer(object): }) else: from netharn.util import util_inspect + _lut = {} + + optim_modules = [ + torch.optim, + ] + try: + # Allow coerce to use torch_optimizer package if available import torch_optimizer except Exception: torch_optimizer = None - - _lut = {} - - if torch_optimizer is not None: - # known = ['AccSGD', 'AdaBound', 'AdaMod', 'DiffGrad', 'Lamb', - # 'Lookahead', 'NovoGrad', 'RAdam', 'SGDW', 'Yogi'] - # if 0: - # for key in known: - # cls = getattr(torch_optimizer, key, None) - # print('cls = {!r}'.format(cls)) - # defaultkw = util_inspect.default_kwargs(cls) - # print('defaultkw = {!r}'.format(defaultkw)) - # _lut.update({k.lower(): k for k in known}) + else: + optim_modules.append(torch_optimizer) _lut.update({ - k: c.__name__ + k.lower(): c.__name__ for k, c in torch_optimizer._NAME_OPTIM_MAP.items()}) _lut.update({ @@ -282,23 +483,18 @@ class Optimizer(object): key = _lut[key] - cls = getattr(torch.optim, key, None) - if cls is not None: - defaultkw = util_inspect.default_kwargs(cls) - kw = defaultkw.copy() - kw.update() - optim_ = (cls, kw) - else: - if torch_optimizer is None: - raise KeyError(key) - cls = getattr(torch_optimizer, key, None) + cls = None + for module in optim_modules: + cls = getattr(module, key, None) if cls is not None: defaultkw = util_inspect.default_kwargs(cls) kw = defaultkw.copy() kw.update() optim_ = (cls, kw) - else: - raise KeyError(key) + break + + if cls is None: + raise KeyError(key) return optim_ diff --git a/netharn/data/channel_spec.py b/netharn/data/channel_spec.py index 
08e2f21c57731a3d1b47b929e83839049e3c2bea..06f653d7171f0fbf09692bd4c33305a8d886b0a9 100644 --- a/netharn/data/channel_spec.py +++ b/netharn/data/channel_spec.py @@ -166,6 +166,46 @@ class ChannelSpec(ub.NiceRepr): for spec in stream_specs: yield spec + def streams(self): + """ + Breaks this spec up into one spec for each early-fused input stream + """ + streams = [self.__class__(spec) for spec in self.keys()] + return streams + + def difference(self, other): + """ + Set difference + + Example: + >>> self = ChannelSpec('rgb|disparity,flowx|flowy') + >>> other = ChannelSpec('rgb') + >>> self.difference(other) + >>> other = ChannelSpec('flowx') + >>> self.difference(other) + """ + assert len(list(other.keys())) == 1, 'can take diff with one stream' + other_norm = ub.oset(ub.peek(other.normalize().values())) + self_norm = self.normalize() + + new_streams = [] + for key, parts in self_norm.items(): + new_parts = ub.oset(parts) - ub.oset(other_norm) + # shrink the representation of a complex r|g|b to an alias if + # possible. + # TODO: make this more efficient + for alias, alias_spec in self._known.items(): + alias_parts = ub.oset(alias_spec.split('|')) + index = subsequence_index(new_parts, alias_parts) + if index is not None: + oset_delitem(new_parts, index) + oset_insert(new_parts, index.start, alias) + new_stream = '|'.join(new_parts) + new_streams.append(new_stream) + new_spec = ','.join(new_streams) + new = self.__class__(new_spec) + return new + def sizes(self): """ Number of dimensions for each fused stream channel @@ -244,12 +284,12 @@ class ChannelSpec(ub.NiceRepr): stream. Args: - item (dict): a batch item + item (Dict[str, Tensor]): a batch item containing unfused parts axis (int, default=0): concatenation dimension Returns: - Dict[str, Tensor]: mapping between input stream and its early fused - tensor input. + Dict[str, Tensor]: + mapping between input stream and its early fused tensor input. Example: >>> import torch @@ -284,6 +324,9 @@ class ChannelSpec(ub.NiceRepr): """ break an early fused item into its components + Args: + inputs (Dict[str, Tensor]): dictionary of components + Example: >>> import torch >>> dims = (4, 4) @@ -313,6 +356,129 @@ class ChannelSpec(ub.NiceRepr): idx1 = idx2 return components + def component_indices(self, axis=2): + """ + Look up component indices within fused streams + + Example: + >>> import torch + >>> dims = (4, 4) + >>> inputs = ['flowx', 'flowy', 'disparity'] + >>> self = ChannelSpec('disparity,flowx|flowy') + >>> component_indices = self.component_indices() + >>> print('component_indices = {!r}'.format(component_indices)) + """ + parsed = self.parse() + component_indices = dict() + for key, parts in parsed.items(): + idx1 = 0 + for part in parts: + size = self._size_lut.get(part, 1) + idx2 = idx1 + size + index = ([slice(None)] * axis + [slice(idx1, idx2)]) + idx1 = idx2 + component_indices[part] = (key, index) + return component_indices + + +def subsequence_index(oset1, oset2): + """ + Returns a slice into the first items indicating the position of + the second items if they exist. + + This is a variant of the substring problem. 
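+ + Args: + oset1 (ub.oset): sequence to search within + oset2 (ub.oset): candidate contiguous subsequence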
+ + Returns: + None | slice + + Example: + >>> oset1 = ub.oset([1, 2, 3, 4, 5, 6]) + >>> oset2 = ub.oset([2, 3, 4]) + >>> index = subsequence_index(oset1, oset2) + >>> assert index + + >>> oset1 = ub.oset([1, 2, 3, 4, 5, 6]) + >>> oset2 = ub.oset([2, 4, 3]) + >>> index = subsequence_index(oset1, oset2) + >>> assert not index + """ + if len(oset2) == 0: + base = 0 + else: + item1 = oset2[0] + try: + base = oset1.index(item1) + except (IndexError, KeyError): + base = None + + index = None + if base is not None: + sl = slice(base, base + len(oset2)) + subset = oset1[sl] + if subset == oset2: + index = sl + return index + + +def oset_insert(self, index, obj): + """ + self = ub.oset() + oset_insert(self, 0, 'a') + oset_insert(self, 0, 'b') + oset_insert(self, 0, 'c') + oset_insert(self, 1, 'd') + oset_insert(self, 2, 'e') + oset_insert(self, 0, 'f') + """ + if obj not in self: + # Bump index of every item after the insert position + for key in self.items[index:]: + self.map[key] = self.map[key] + 1 + self.items.insert(index, obj) + self.map[obj] = index + + +def oset_delitem(self, index): + """ + for ubelt oset, todo contribute back to luminosoinsight + + >>> self = ub.oset([1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> index = slice(3, 5) + >>> oset_delitem(self, index) + + self = ub.oset(['r', 'g', 'b', 'disparity']) + index = slice(0, 3) + oset_delitem(self, index) + + """ + if isinstance(index, slice) and index == ub.orderedset.SLICE_ALL: + self.clear() + else: + if ub.orderedset.is_iterable(index): + to_remove = [self.items[i] for i in index] + elif isinstance(index, slice) or hasattr(index, "__index__"): + to_remove = self.items[index] + else: + raise TypeError("Don't know how to index an OrderedSet by %r" % index) + + if isinstance(to_remove, list): + # Modified version of discard slightly more efficient for multiple + # items + remove_idxs = sorted([self.map[key] for key in to_remove], reverse=True) + + for key in to_remove: + del self.map[key] + + for idx in remove_idxs: + del self.items[idx] + + for k, v in self.map.items(): + # I think there is a more efficient way to do this? + num_after = sum(v >= i for i in remove_idxs) + if num_after: + self.map[k] = v - num_after + else: + self.discard(to_remove) if __name__ == '__main__': """ diff --git a/netharn/data/data_containers.py b/netharn/data/data_containers.py index bda1696f8113f5362a3d5b87f5c72f4034044176..3a733a7ac7fc93b89ab73b91f7d4fb1def2026e7 100644 --- a/netharn/data/data_containers.py +++ b/netharn/data/data_containers.py @@ -54,16 +54,16 @@ class BatchContainer(ub.NiceRepr): outputs or a set of items that have already been collated. Attributes: - data (List): Unlike ItemContainer, data is always a list where - len(data) is the number of devices this batch will run on. - Each item in the list may be either a pre-batched Tensor (in the - case where the each item in the batch has the same shape) or a list - of individual item Tensors (in the case where different batch items + data (List[Any]): Unlike ItemContainer, data is always a list where + len(data) is the number of devices this batch will run on. Each + item in the list may be either a pre-batched Tensor (in the case + where the each item in the batch has the same shape) or a list of + individual item Tensors (in the case where different batch items may have different shapes). 
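+ + meta (Dict): collation options recorded at construction time (stack, padding_value, cpu_only, pad_dims)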
""" def __init__(self, data, stack=False, padding_value=-1, cpu_only=False, pad_dims=2): - self.data = data + self.data = data # type: list self.meta = { 'stack': stack, 'padding_value': padding_value, @@ -71,9 +71,45 @@ class BatchContainer(ub.NiceRepr): 'pad_dims': pad_dims, } + @property + def nestshape(self): + return nestshape(self.data) + + def numel(self): + """ + The number of scalar elements held by this container + """ + shapes = self.nestshape + total = sum([np.prod(s) for s in shapes]) + return total + + @property + def packshape(self): + """ + The shape of this data if it was packed + """ + # shape = np.maximum.reduce(self.nestshape) + # return shape + dim = 0 + if self.stack: + # Should be a straight forward concatenation + shapes = [d.shape for d in self.data] + max_shape = np.maximum.reduce(shapes) # should all be the same here + stacked_dim = sum([s[dim] for s in shapes]) + max_shape[dim] = stacked_dim + pack_shape = tuple(max_shape.tolist()) + return pack_shape + else: + shapes = nestshape(self.data) + max_shape = np.maximum.reduce(shapes) + stacked_dim = sum([s[dim] for s in shapes]) + max_shape[dim] = stacked_dim + pack_shape = tuple(max_shape.tolist()) + return pack_shape + def __nice__(self): try: - shape_repr = ub.repr2(nestshape(self.data), nl=-2) + shape_repr = ub.repr2(self.nestshape, nl=-2) return 'nestshape(data)={}'.format(shape_repr) except Exception: return super().__repr__() @@ -152,12 +188,11 @@ class BatchContainer(ub.NiceRepr): def to(self, device): """ inplace move data onto a device """ - for item in self.data: - if torch.is_tensor(item): - item.to(item) - else: - for subitem in item: - subitem.to(device) + from netharn.util.util_json import IndexableWalker + walker = IndexableWalker(self.data) + for path, val in walker: + if torch.is_tensor(val): + walker[path] = val.to(device) return self @@ -184,13 +219,16 @@ class ItemContainer(ub.NiceRepr): 'pad_dims': pad_dims, } + @property + def nestshape(self): + return nestshape(self.data) + def __nice__(self): try: - shape_repr = ub.repr2(nestshape(self.data), nl=-2) + shape_repr = ub.repr2(self.nestshape, nl=-2) return 'nestshape(data)={}'.format(shape_repr) except Exception: return super().__repr__() - # return 'nestshape(data)={}, **{}'.format(shape_repr, ub.repr2(self.meta, nl=0)) @classmethod def demo(cls, key='img', rng=None, **kwargs): @@ -869,6 +907,14 @@ class ContainerXPU(XPU): def nestshape(data): + """ + Examine nested shape of the data + + Example: + >>> data = [np.arange(10), np.arange(13)] + >>> nestshape(data) + [(10,), (13,)] + """ import ubelt as ub def _recurse(d): diff --git a/netharn/device.py b/netharn/device.py index 6d8a482c0c0ef8b8469a520bf94f8223c5e55336..8cc92558d8ae9184420f56b5ad831cf4f749137e 100644 --- a/netharn/device.py +++ b/netharn/device.py @@ -116,10 +116,10 @@ class XPU(ub.NiceRepr): if check: if not XPU.exists(item): if isinstance(item, int) and not torch.cuda.is_available(): - raise ValueError('XPU {} does not exist. ' + raise ValueError('XPU {!r} does not exist. 
' 'CUDA is not available'.format(item)) else: - raise ValueError('XPU {} does not exist.'.format(item)) + raise ValueError('XPU {!r} does not exist.'.format(item)) if item is None: xpu._main_device_id = None diff --git a/netharn/examples/style_transfer.py b/netharn/examples/style_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..af47dc4af1ee4283b5406b7051877e988d180ea2 --- /dev/null +++ b/netharn/examples/style_transfer.py @@ -0,0 +1,443 @@ +""" +Adapted from: https://github.com/rrmina/fast-neural-style-pytorch/blob/master/video.py +""" +from torchvision import models, transforms +import sys +import torch +import torch.nn as nn +import ubelt as ub +import netharn as nh +import scriptconfig as scfg +import kwimage + + +class StyleTransferConfig(scfg.Config): + default = { + 'name': scfg.Value('style_example', help='A human readable tag that is "name" for humans'), + 'workdir': scfg.Path('~/work/netharn', help='Dump all results in your workdir'), + + 'workers': scfg.Value(2, help='number of parallel dataloading jobs'), + 'xpu': scfg.Value('auto', help='See netharn.XPU for details. can be auto/cpu/xpu/cuda0/0,1,2,3)'), + + 'datasets': scfg.Value('special:shapes256', help='Either a special key or a coco file'), + 'train_dataset': scfg.Value(None), + 'vali_dataset': scfg.Value(None), + 'test_dataset': scfg.Value(None), + + 'sampler_backend': scfg.Value(None, help='ndsampler backend'), + + 'channels': scfg.Value('rgb', help='special channel code. See ChannelSpec'), + + # 'arch': scfg.Value('resnet50', help='Network architecture code'), + 'optim': scfg.Value('adam', help='Weight optimizer. Can be SGD, ADAM, ADAMW, etc..'), + + 'input_dims': scfg.Value((256, 256), help='Window size to input to the network'), + + # TODO + 'normalize_inputs': scfg.Value(True, help=( + 'if True, precompute training mean and std for data whitening')), + + 'balance': scfg.Value(None, help='balance strategy. Can be category or None'), + # 'augmenter': scfg.Value('simple', help='type of training dataset augmentation'), + + 'batch_size': scfg.Value(3, help='number of items per batch'), + 'num_batches': scfg.Value('auto', help='Number of batches per epoch (mainly for balanced batch sampling)'), + + 'max_epoch': scfg.Value(140, help='Maximum number of epochs'), + 'patience': scfg.Value(140, help='Maximum "bad" validation epochs before early stopping'), + + 'lr': scfg.Value(1e-4, help='Base learning rate'), + 'decay': scfg.Value(1e-5, help='Base weight decay'), + 'schedule': scfg.Value( + 'step90-120', help=( + 'Special coercible netharn code. Eg: onecycle50, step50, gamma, ReduceLROnPlateau-p10-c10')), + 'init': scfg.Value('noop', help='How to initialize weights: e.g.
noop, kaiming_normal, path-to-a-pretrained-model)'), + # 'pretrained': scfg.Path(help=('alternative way to specify a path to a pretrained model')), + } + + +class StyleTransferHarn(nh.FitHarn): + + def after_initialize(harn): + STYLE_IMAGE_PATH = ub.grabdata('https://raw.githubusercontent.com/iamRusty/fast-neural-style-pytorch/master/images/mosaic.jpg') + + device = harn.xpu.device + harn.MSELoss = nn.MSELoss().to(device) + + vgg_path = ub.grabdata('https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth') + + # TODO: should be tracked + harn.vgg = VGG16(**{'vgg_path': vgg_path}) + harn.vgg = harn.xpu.move(harn.vgg) + + def itot(img, max_size=None): + # Rescale the image + if (max_size is None): + itot_t = transforms.Compose([ + # transforms.ToPILImage(), + transforms.ToTensor(), + transforms.Lambda(lambda x: x.mul(255)) + ]) + else: + H, W, C = img.shape + image_size = tuple( + [int((float(max_size) / max([H, W])) * x) for x in [H, W]]) + itot_t = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Lambda(lambda x: x.mul(255)) + ]) + + # Convert image to tensor + tensor = itot_t(img) + + # Add the batch_size dimension + tensor = tensor.unsqueeze(dim=0) + return tensor + + # Get Style Features + imagenet_neg_mean = torch.tensor( + [-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1, 3, 1, 1).to(device) + style_image = kwimage.imread(STYLE_IMAGE_PATH) + style_tensor = itot(style_image).to(device) + style_tensor = style_tensor.add(imagenet_neg_mean) + B, C, H, W = style_tensor.shape + + harn.imagenet_neg_mean = imagenet_neg_mean + harn.style_tensor = style_tensor + + batch_size = harn.script_config['batch_size'] + im = style_tensor.expand([batch_size, C, H, W]) + + style_features = harn.vgg(im) + style_gram = {} + for key, value in style_features.items(): + style_gram[key] = gram(value) + harn.style_gram = style_gram + + def run_batch(harn, batch): + """ + Ignore: + import sys, ubelt + sys.path.append(ubelt.expandpath('~/code/netharn')) + from netharn.examples.style_transfer import * # NOQA + kw = {} + cmdline = False + harn = setup_harn() + harn.initialize() + batch = harn._demo_batch() + harn.run_batch(batch) + """ + # Current Batch size in case of odd batches + content_batch, _ = batch + curr_batch_size = content_batch.shape[0] + + model = harn.model + # Zero-out Gradients + + # Generate images and get features + content_batch = harn.xpu.move(content_batch[:, [2, 1, 0]]) + + generated_batch = model(content_batch) + + generated_batch = harn.model(content_batch) + content_features = harn.vgg(content_batch.add(harn.imagenet_neg_mean)) + generated_features = harn.vgg(generated_batch.add(harn.imagenet_neg_mean)) + + # Content Loss + CONTENT_WEIGHT = 17 + STYLE_WEIGHT = 50 + + content_loss = CONTENT_WEIGHT * \ + harn.MSELoss( + content_features['relu2_2'], + generated_features['relu2_2']) + + # Style Loss + style_loss = 0 + for key, value in generated_features.items(): + s_loss = harn.MSELoss( + gram(value), + harn.style_gram[key][:curr_batch_size] + ) + style_loss += s_loss + style_loss *= STYLE_WEIGHT + + # Total Loss + loss_parts = { + 'content_loss': content_loss, + 'style_loss': style_loss, + } + return generated_batch, loss_parts + + def on_batch(harn, batch, generated_batch, loss): + _do_draw = harn.batch_index % 500 == 0 + _do_draw |= harn.batch_index < 4 + if _do_draw: + # Save sample generated image + from os.path import join + dpath = ub.ensuredir((harn.train_dpath, 'monitor', harn.current_tag)) + 
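# Descriptive note on the sample-saving step below: it detaches one generated
# sample, moves it to the CPU, converts it from a CxHxW tensor to an HxWxC
# array, and writes it into the run's ``monitor`` directory so training
# progress can be inspected visually. Values are clipped to [0, 255] because
# the transformer operates on 255-scaled inputs rather than [0, 1] floats.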
sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0) + sample_image = sample_tensor.clone().detach().cpu().squeeze().numpy().transpose(1, 2, 0) + sample_image_path = join(dpath, "sample0_" + str(harn.batch_index) + '_' + str(harn.batch_index) + ".png") + kwimage.imwrite(sample_image_path, sample_image.clip(0, 255)) + print("Saved sample tranformed image at {}".format(sample_image_path)) + + +def setup_harn(cmdline=False, **kw): + """ + Ignore: + kw = {} + cmdline = False + harn = setup_harn() + """ + config = StyleTransferConfig(default=kw) + config.load(cmdline=cmdline) + print('config = {}'.format(ub.repr2(config.asdict()))) + + nh.configure_hacks(config) + + dataset_info = nh.api.DatasetInfo.coerce(config) + + # input_stats = dataset_info['input_stats'] + model = (TransformerNetwork, {}) + + hyper = nh.HyperParams( + name=config['name'], + + workdir=config['workdir'], + xpu=nh.XPU.coerce(config['xpu']), + + datasets=dataset_info['torch_datasets'], + loaders=dataset_info['torch_loaders'], + + model=model, + criterion=None, + initializer=None, + + optimizer=nh.Optimizer.coerce(config), + dynamics=nh.Dynamics.coerce(config), + scheduler=nh.Scheduler.coerce(config), + + monitor=(nh.Monitor, { + 'minimize': ['loss'], + 'patience': config['patience'], + 'max_epoch': config['max_epoch'], + 'smoothing': 0.0, + }), + other={ + 'name': config['name'], + 'batch_size': config['batch_size'], + 'balance': config['balance'], + }, + extra={ + 'argv': sys.argv, + 'config': ub.repr2(config.asdict()), + } + ) + harn = StyleTransferHarn(hyper=hyper) + harn.preferences.update({ + 'num_keep': 3, + 'keep_freq': 10, + 'tensorboard_groups': ['loss'], + 'eager_dump_tensorboard': True, + }) + harn.intervals.update({}) + harn.script_config = config + return harn + + +def gram(tensor): + B, C, H, W = tensor.shape + x = tensor.view(B, C, H * W) + x_t = x.transpose(1, 2) + return torch.bmm(x, x_t) / (C * H * W) + + +class VGG16(nn.Module): + def __init__(self, vgg_path="models/vgg16-00b39a1b.pth"): + super(VGG16, self).__init__() + self.vgg_path = vgg_path + # Load VGG Skeleton, Pretrained Weights + vgg16_features = models.vgg16(pretrained=False) + vgg16_features.load_state_dict(torch.load(vgg_path), strict=False) + self.features = vgg16_features.features + + # Turn-off Gradient History + for param in self.features.parameters(): + param.requires_grad = False + + def forward(self, x): + layers = { + '3': 'relu1_2', + '8': 'relu2_2', + '15': 'relu3_3', + '22': 'relu4_3'} + features = {} + for name, layer in self.features._modules.items(): + x = layer(x) + if name in layers: + features[layers[name]] = x + if name == '22': + break + + return features + + +class TransformerNetwork(nn.Module): + """Feedforward Transformation Network without Tanh + reference: https://arxiv.org/abs/1603.08155 + exact architecture: https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf + """ + + def __init__(self): + super().__init__() + self.ConvBlock = nn.Sequential( + ConvLayer(3, 32, 9, 1), + nn.ReLU(), + ConvLayer(32, 64, 3, 2), + nn.ReLU(), + ConvLayer(64, 128, 3, 2), + nn.ReLU() + ) + self.ResidualBlock = nn.Sequential( + ResidualLayer(128, 3), + ResidualLayer(128, 3), + ResidualLayer(128, 3), + ResidualLayer(128, 3), + ResidualLayer(128, 3) + ) + self.DeconvBlock = nn.Sequential( + DeconvLayer(128, 64, 3, 2, 1), + nn.ReLU(), + DeconvLayer(64, 32, 3, 2, 1), + nn.ReLU(), + ConvLayer(32, 3, 9, 1, norm="None") + ) + + def forward(self, x): + x = self.ConvBlock(x) + x = self.ResidualBlock(x) + out = 
self.DeconvBlock(x) + return out + + +class TransformerNetworkTanh(TransformerNetwork): + """A modification of the transformation network that uses Tanh function as output + This follows more closely the architecture outlined in the original paper's supplementary material + his net produces darker images and provides retro styling effect + Reference: https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf + """ + # override __init__ method + + def __init__(self, tanh_multiplier=150): + super(TransformerNetworkTanh, self).__init__() + # Add a Tanh layer before output + self.DeconvBlock = nn.Sequential( + DeconvLayer(128, 64, 3, 2, 1), + nn.ReLU(), + DeconvLayer(64, 32, 3, 2, 1), + nn.ReLU(), + ConvLayer(32, 3, 9, 1, norm="None"), + nn.Tanh() + ) + self.tanh_multiplier = tanh_multiplier + + # Override forward method + def forward(self, x): + return super(TransformerNetworkTanh, self).forward( + x) * self.tanh_multiplier + + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size, stride, norm="instance"): + super(ConvLayer, self).__init__() + # Padding Layers + padding_size = kernel_size // 2 + self.reflection_pad = nn.ReflectionPad2d(padding_size) + + # Convolution Layer + self.conv_layer = nn.Conv2d( + in_channels, out_channels, kernel_size, stride) + + # Normalization Layers + self.norm_type = norm + if (norm == "instance"): + self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) + elif (norm == "batch"): + self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) + + def forward(self, x): + x = self.reflection_pad(x) + x = self.conv_layer(x) + if (self.norm_type == "None"): + out = x + else: + out = self.norm_layer(x) + return out + + +class ResidualLayer(nn.Module): + """ + Deep Residual Learning for Image Recognition + + https://arxiv.org/abs/1512.03385 + """ + + def __init__(self, channels=128, kernel_size=3): + super(ResidualLayer, self).__init__() + self.conv1 = ConvLayer(channels, channels, kernel_size, stride=1) + self.relu = nn.ReLU() + self.conv2 = ConvLayer(channels, channels, kernel_size, stride=1) + + def forward(self, x): + identity = x # preserve residual + out = self.relu(self.conv1(x)) # 1st conv layer + activation + out = self.conv2(out) # 2nd conv layer + out = out + identity # add residual + return out + + +class DeconvLayer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, + stride, output_padding, norm="instance"): + super(DeconvLayer, self).__init__() + + # Transposed Convolution + padding_size = kernel_size // 2 + self.conv_transpose = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride, + padding_size, + output_padding) + + # Normalization Layers + self.norm_type = norm + if (norm == "instance"): + self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) + elif (norm == "batch"): + self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) + + def forward(self, x): + x = self.conv_transpose(x) + if (self.norm_type == "None"): + out = x + else: + out = self.norm_layer(x) + return out + + +if __name__ == '__main__': + """ + CommandLine: + python ~/code/netharn/netharn/examples/style_transfer.py \ + --xpu=0 \ + --train_dataset=shapes1024 \ + --vali_dataset=shapes1024 \ + """ + harn = setup_harn() + harn.run() diff --git a/netharn/examples/style_transfer_orig.py b/netharn/examples/style_transfer_orig.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee7275d5f7a6031efa654e38d325b70fd60e004 --- /dev/null +++ 
b/netharn/examples/style_transfer_orig.py @@ -0,0 +1,584 @@ +from torchvision import datasets, models, transforms +import cv2 +import torch +import torch.nn as nn +import torch.optim as optim +import random +import numpy as np +import time +import ubelt as ub + +notes = """ +# GLOBAL SETTINGS + +!wget https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth + +# download style image +!wget https://raw.githubusercontent.com/iamRusty/fast-neural-style-pytorch/master/images/mosaic.jpg +""" + + +SAVE_MODEL_EVERY = 500 # 2,000 Images with batch size 4 +SEED = 35 +BATCH_SIZE = 4 +CONTENT_WEIGHT = 17 +STYLE_WEIGHT = 50 +TV_WEIGHT = 1e-6 +ADAM_LR = 0.001 +NUM_EPOCHS = 1 + + +class VGG19(nn.Module): + def __init__(self, vgg_path="models/vgg19-d01eb7cb.pth"): + super(VGG19, self).__init__() + # Load VGG Skeleton, Pretrained Weights + vgg19_features = models.vgg19(pretrained=False) + vgg19_features.load_state_dict(torch.load(vgg_path), strict=False) + self.features = vgg19_features.features + + # Turn-off Gradient History + for param in self.features.parameters(): + param.requires_grad = False + + def forward(self, x): + layers = { + '3': 'relu1_2', + '8': 'relu2_2', + '17': 'relu3_4', + '22': 'relu4_2', + '26': 'relu4_4', + '35': 'relu5_4'} + features = {} + for name, layer in self.features._modules.items(): + x = layer(x) + if name in layers: + features[layers[name]] = x + + return features + + +class VGG16(nn.Module): + def __init__(self, vgg_path="models/vgg16-00b39a1b.pth"): + super(VGG16, self).__init__() + # Load VGG Skeleton, Pretrained Weights + vgg16_features = models.vgg16(pretrained=False) + vgg16_features.load_state_dict(torch.load(vgg_path), strict=False) + self.features = vgg16_features.features + + # Turn-off Gradient History + for param in self.features.parameters(): + param.requires_grad = False + + def forward(self, x): + layers = { + '3': 'relu1_2', + '8': 'relu2_2', + '15': 'relu3_3', + '22': 'relu4_3'} + features = {} + for name, layer in self.features._modules.items(): + x = layer(x) + if name in layers: + features[layers[name]] = x + if (name == '22'): + break + + return features + +# Gram Matrix + + +def gram(tensor): + B, C, H, W = tensor.shape + x = tensor.view(B, C, H * W) + x_t = x.transpose(1, 2) + return torch.bmm(x, x_t) / (C * H * W) + +# Load image file + + +def load_image(path): + # Images loaded as BGR + img = cv2.imread(path) + return img + +# Show image + + +def show(img): + import matplotlib.pyplot as plt + # Convert from BGR to RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # imshow() only accepts float [0,1] or int [0,255] + img = np.array(img / 255).clip(0, 1) + + plt.figure(figsize=(10, 5)) + plt.imshow(img) + plt.show() + + +def saveimg(img, image_path): + img = img.clip(0, 255) + cv2.imwrite(image_path, img) + +# Preprocessing ~ Image to Tensor + + +def itot(img, max_size=None): + # Rescale the image + if (max_size is None): + itot_t = transforms.Compose([ + # transforms.ToPILImage(), + transforms.ToTensor(), + transforms.Lambda(lambda x: x.mul(255)) + ]) + else: + H, W, C = img.shape + image_size = tuple( + [int((float(max_size) / max([H, W])) * x) for x in [H, W]]) + itot_t = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Lambda(lambda x: x.mul(255)) + ]) + + # Convert image to tensor + tensor = itot_t(img) + + # Add the batch_size dimension + tensor = tensor.unsqueeze(dim=0) + return tensor + +# Preprocessing ~ Tensor to Image + + +def ttoi(tensor): + # Add the means + # 
ttoi_t = transforms.Compose([ + # transforms.Normalize([-103.939, -116.779, -123.68],[1,1,1])]) + + # Remove the batch_size dimension + tensor = tensor.squeeze() + #img = ttoi_t(tensor) + img = tensor.cpu().numpy() + + # Transpose from [C, H, W] -> [H, W, C] + img = img.transpose(1, 2, 0) + return img + + +def transfer_color(src, dest): + """ + Transfer Color using YIQ colorspace. Useful in preserving colors in style transfer. + This method assumes inputs of shape [Height, Width, Channel] in BGR Color Space + """ + src, dest = src.clip(0, 255), dest.clip(0, 255) + + # Resize src to dest's size + H, W, _ = src.shape + dest = cv2.resize(dest, dsize=(W, H), interpolation=cv2.INTER_CUBIC) + + # 1 Extract the Destination's luminance + dest_gray = cv2.cvtColor(dest, cv2.COLOR_BGR2GRAY) + # 2 Convert the Source from BGR to YIQ/YCbCr + src_yiq = cv2.cvtColor(src, cv2.COLOR_BGR2YCrCb) + # 3 Combine Destination's luminance and Source's IQ/CbCr + src_yiq[..., 0] = dest_gray + + return cv2.cvtColor(src_yiq, cv2.COLOR_YCrCb2BGR).clip( + 0, 255) # 4 Convert new image from YIQ back to BGR + + +def plot_loss_hist(c_loss, s_loss, total_loss, title="Loss History"): + import matplotlib.pyplot as plt + x = [i for i in range(len(total_loss))] + plt.figure(figsize=[10, 6]) + plt.plot(x, c_loss, label="Content Loss") + plt.plot(x, s_loss, label="Style Loss") + plt.plot(x, total_loss, label="Total Loss") + + plt.legend() + plt.xlabel('Every 500 iterations') + plt.ylabel('Loss') + plt.title(title) + plt.show() + + +class ImageFolderWithPaths(datasets.ImageFolder): + """Custom dataset that includes image file paths. + Extends torchvision.datasets.ImageFolder() + Reference: https://discuss.pytorch.org/t/dataloader-filenames-in-each-batch/4212/2 + """ + # override the __getitem__ method. 
this is the method dataloader calls + + def __getitem__(self, index): + # this is what ImageFolder normally returns + original_tuple = super(ImageFolderWithPaths, self).__getitem__(index) + + # the image file path + path = self.imgs[index][0] + + # make a new tuple that includes original and the path + tuple_with_path = (*original_tuple, path) + return tuple_with_path + + +class TransformerNetwork(nn.Module): + """Feedforward Transformation Network without Tanh + reference: https://arxiv.org/abs/1603.08155 + exact architecture: https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf + """ + + def __init__(self): + super().__init__() + self.ConvBlock = nn.Sequential( + ConvLayer(3, 32, 9, 1), + nn.ReLU(), + ConvLayer(32, 64, 3, 2), + nn.ReLU(), + ConvLayer(64, 128, 3, 2), + nn.ReLU() + ) + self.ResidualBlock = nn.Sequential( + ResidualLayer(128, 3), + ResidualLayer(128, 3), + ResidualLayer(128, 3), + ResidualLayer(128, 3), + ResidualLayer(128, 3) + ) + self.DeconvBlock = nn.Sequential( + DeconvLayer(128, 64, 3, 2, 1), + nn.ReLU(), + DeconvLayer(64, 32, 3, 2, 1), + nn.ReLU(), + ConvLayer(32, 3, 9, 1, norm="None") + ) + + def forward(self, x): + x = self.ConvBlock(x) + x = self.ResidualBlock(x) + out = self.DeconvBlock(x) + return out + + +class TransformerNetworkTanh(TransformerNetwork): + """A modification of the transformation network that uses Tanh function as output + This follows more closely the architecture outlined in the original paper's supplementary material + his net produces darker images and provides retro styling effect + Reference: https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf + """ + # override __init__ method + + def __init__(self, tanh_multiplier=150): + super(TransformerNetworkTanh, self).__init__() + # Add a Tanh layer before output + self.DeconvBlock = nn.Sequential( + DeconvLayer(128, 64, 3, 2, 1), + nn.ReLU(), + DeconvLayer(64, 32, 3, 2, 1), + nn.ReLU(), + ConvLayer(32, 3, 9, 1, norm="None"), + nn.Tanh() + ) + self.tanh_multiplier = tanh_multiplier + + # Override forward method + def forward(self, x): + return super(TransformerNetworkTanh, self).forward( + x) * self.tanh_multiplier + + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size, stride, norm="instance"): + super(ConvLayer, self).__init__() + # Padding Layers + padding_size = kernel_size // 2 + self.reflection_pad = nn.ReflectionPad2d(padding_size) + + # Convolution Layer + self.conv_layer = nn.Conv2d( + in_channels, out_channels, kernel_size, stride) + + # Normalization Layers + self.norm_type = norm + if (norm == "instance"): + self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) + elif (norm == "batch"): + self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) + + def forward(self, x): + x = self.reflection_pad(x) + x = self.conv_layer(x) + if (self.norm_type == "None"): + out = x + else: + out = self.norm_layer(x) + return out + + +class ResidualLayer(nn.Module): + """ + Deep Residual Learning for Image Recognition + + https://arxiv.org/abs/1512.03385 + """ + + def __init__(self, channels=128, kernel_size=3): + super(ResidualLayer, self).__init__() + self.conv1 = ConvLayer(channels, channels, kernel_size, stride=1) + self.relu = nn.ReLU() + self.conv2 = ConvLayer(channels, channels, kernel_size, stride=1) + + def forward(self, x): + identity = x # preserve residual + out = self.relu(self.conv1(x)) # 1st conv layer + activation + out = self.conv2(out) # 2nd conv layer + out = out + identity # add 
residual + return out + + +class DeconvLayer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, + stride, output_padding, norm="instance"): + super(DeconvLayer, self).__init__() + + # Transposed Convolution + padding_size = kernel_size // 2 + self.conv_transpose = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride, + padding_size, + output_padding) + + # Normalization Layers + self.norm_type = norm + if (norm == "instance"): + self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) + elif (norm == "batch"): + self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) + + def forward(self, x): + x = self.conv_transpose(x) + if (self.norm_type == "None"): + out = x + else: + out = self.norm_layer(x) + return out + + +def load_cifar(key='cifar10', workdir=None, transform=None): + """ + key = 'cifar10' + load_cifar(key, workdir=None) + """ + import torchvision + import pickle + import os + if workdir is None: + workdir = ub.ensure_app_cache_dir('netharn') + + if key == 'cifar10': + DATASET = torchvision.datasets.CIFAR10 + dset = DATASET(root=workdir, download=True, transform=transform) + meta_fpath = os.path.join(dset.root, dset.base_folder, 'batches.meta') + meta_dict = pickle.load(open(meta_fpath, 'rb')) + dset.classes = meta_dict['label_names'] + # For some reason the torchvision objects dont have the label names + # in the dataset. But the download directory will have them. + # classes = [ + # 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', + # 'horse', 'ship', 'truck', + # ] + elif key == 'cifar100': + DATASET = torchvision.datasets.CIFAR100 + dset = DATASET(root=workdir, download=True, transform=transform) + meta_fpath = os.path.join(dset.root, dset.base_folder, 'meta') + meta_dict = pickle.load(open(meta_fpath, 'rb')) + dset.classes = meta_dict['fine_label_names'] + return dset + + +def train(): + """ + CommandLine: + xdoctest -m /home/joncrall/code/netharn/netharn/examples/style_transfer.py train + + Example: + >>> # xdoctest: +SKIP + >>> import sys, ubelt + >>> sys.path.append(ubelt.expandpath('~/code/netharn')) + >>> from netharn.examples.style_transfer import * # NOQA + >>> train() + """ + + STYLE_IMAGE_PATH = ub.grabdata('https://raw.githubusercontent.com/iamRusty/fast-neural-style-pytorch/master/images/mosaic.jpg') + vgg_path = ub.grabdata('https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth') + workdir = ub.ensure_app_cache_dir('netharn') + + # Seeds + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + np.random.seed(SEED) + random.seed(SEED) + + # Device + device = ("cuda" if torch.cuda.is_available() else "cpu") + + # Dataset and Dataloader + TRAIN_IMAGE_SIZE = 256 + transform = transforms.Compose([ + transforms.Resize(TRAIN_IMAGE_SIZE), + transforms.CenterCrop(TRAIN_IMAGE_SIZE), + transforms.ToTensor(), + transforms.Lambda(lambda x: x.mul(255)) + ]) + + SAVE_MODEL_PATH = workdir + SAVE_IMAGE_PATH = workdir + # DATASET_PATH = "/content/train" + # train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform) + + train_dataset = load_cifar(transform=transform) + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=BATCH_SIZE, shuffle=True) + + # Load networks + model = VGG16(vgg_path).to(device) + + # Get Style Features + imagenet_neg_mean = torch.tensor( + [-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1, 3, 1, 1).to(device) + style_image = load_image(STYLE_IMAGE_PATH) + style_tensor = itot(style_image).to(device) + style_tensor = 
style_tensor.add(imagenet_neg_mean) + B, C, H, W = style_tensor.shape + style_features = model(style_tensor.expand([BATCH_SIZE, C, H, W])) + style_gram = {} + for key, value in style_features.items(): + style_gram[key] = gram(value) + + # Optimizer settings + params = list(model.parameters()) + optimizer = optim.Adam(params, lr=ADAM_LR) + + # Loss trackers + content_loss_history = [] + style_loss_history = [] + total_loss_history = [] + batch_content_loss_sum = 0 + batch_style_loss_sum = 0 + batch_total_loss_sum = 0 + + # Optimization/Training Loop + batch_count = 1 + start_time = time.time() + for epoch in range(1, NUM_EPOCHS + 1): + print("========Epoch {}/{}========".format(epoch, NUM_EPOCHS + 1)) + for batch_id, (content_batch, _) in enumerate(train_loader): + # Current Batch size in case of odd batches + curr_batch_size = content_batch.shape[0] + + # Zero-out Gradients + optimizer.zero_grad() + + # Generate images and get features + content_batch = content_batch[:, [2, 1, 0]].to(device) + generated_batch = model(content_batch) + content_features = model(content_batch.add(imagenet_neg_mean)) + generated_features = model(generated_batch.add(imagenet_neg_mean)) + + # Content Loss + MSELoss = nn.MSELoss().to(device) + content_loss = CONTENT_WEIGHT * \ + MSELoss( + content_features['relu2_2'], + generated_features['relu2_2']) + batch_content_loss_sum += content_loss + + # Style Loss + style_loss = 0 + for key, value in generated_features.items(): + s_loss = MSELoss(gram(value), + style_gram[key][:curr_batch_size]) + style_loss += s_loss + style_loss *= STYLE_WEIGHT + batch_style_loss_sum += style_loss + + # Total Loss + total_loss = content_loss + style_loss + batch_total_loss_sum += total_loss.item() + + # Backprop and Weight Update + total_loss.backward() + optimizer.step() + + # Save Model and Print Losses + if (((batch_count - 1) % SAVE_MODEL_EVERY == 0) or (batch_count == NUM_EPOCHS * len(train_loader))): + # Print Losses + print("========Iteration {}/{}========".format(batch_count, + NUM_EPOCHS * len(train_loader))) + print( + "\tContent Loss:\t{:.2f}".format( + batch_content_loss_sum / + batch_count)) + print( + "\tStyle Loss:\t{:.2f}".format( + batch_style_loss_sum / + batch_count)) + print( + "\tTotal Loss:\t{:.2f}".format( + batch_total_loss_sum / + batch_count)) + print( + "Time elapsed:\t{} seconds".format( + time.time() - start_time)) + + # Save Model + checkpoint_path = SAVE_MODEL_PATH + \ + "checkpoint_" + str(batch_count - 1) + ".pth" + torch.save(TransformerNetwork.state_dict(), checkpoint_path) + print( + "Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path)) + + # Save sample generated image + sample_tensor = generated_batch[0].clone( + ).detach().unsqueeze(dim=0) + sample_image = ttoi(sample_tensor.clone().detach()) + sample_image_path = SAVE_IMAGE_PATH + \ + "sample0_" + str(batch_count - 1) + ".png" + saveimg(sample_image, sample_image_path) + # show(sample_image) + print( + "Saved sample tranformed image at {}".format(sample_image_path)) + + # Save loss histories + content_loss_history.append(batch_total_loss_sum / batch_count) + style_loss_history.append(batch_style_loss_sum / batch_count) + total_loss_history.append(batch_total_loss_sum / batch_count) + + # Iterate Batch Counter + batch_count += 1 + + stop_time = time.time() + # Print loss histories + print("Done Training the Transformer Network!") + print("Training Time: {} seconds".format(stop_time - start_time)) + print("========Content Loss========") + print(content_loss_history) + 
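# Illustrative sketch only (not part of the original loop): the histories
# recorded above can also be visualized with the ``plot_loss_hist`` helper
# defined earlier in this file (requires matplotlib):
#   plot_loss_hist(content_loss_history, style_loss_history,
#                  total_loss_history, title="Loss History")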
print("========Style Loss========") + print(style_loss_history) + print("========Total Loss========") + print(total_loss_history) + + # Save TransformerNetwork weights + TransformerNetwork.eval() + TransformerNetwork.cpu() + final_path = SAVE_MODEL_PATH + "transformer_weight.pth" + print("Saving TransformerNetwork weights at {}".format(final_path)) + torch.save(TransformerNetwork.state_dict(), final_path) + print("Done saving final net") diff --git a/netharn/export/__init__.py b/netharn/export/__init__.py index 9b01c180fa4dd18c9d8712221fd31ed580e253b3..15794cc6d13320606e6011b5accdcdcf104f86f2 100644 --- a/netharn/export/__init__.py +++ b/netharn/export/__init__.py @@ -1,10 +1,153 @@ """ -mkinit ~/code/netharn/netharn/export +NOTICE: ``netharn.export`` has been refactored into the packages ``liberator`` +which performs general code extraction and ``torch_liberator`` which is +specific to pytorch. This module is deprecated and will be removed in the +future. + +For now here are old docs, slightly updated to reference the correct packages: + +The package torch_liberator.deployed contains DeployedModel, which consists of +logic to take the model topology definition along with the "best" snapshot in a +training directory and package it up into a standalone zipfile. The +DeployedModel can also be used to reload model from this zipfile. Thus this +zipfile can be passed around as a pytorch model topology+pretrained weights +transfer format. + +The file torch_liberator.exporter contains the code that simply exports the +model toplogy via code Uses static analysis to export relevant code that +defines the model topology into a stanadlone file. As long as your model +definition is indepenent of your training code, then the exported file can be +passed around in a similar way to a caffe prototext file. + + +CommandLine: + # Runs the following example + xdoctest -m netharn.export __doc__:0 + +Example: + >>> # xdoctest: +IGNORE_WANT + >>> # This example will train a small model and then deploy it. + >>> import netharn as nh + >>> import ubelt as ub + >>> # + >>> ################################################# + >>> print('--- STEP 1: TRAIN A MODEL ---') + >>> # This will train a toy model with toy data using netharn + >>> hyper = nh.HyperParams(**{ + >>> 'workdir' : ub.ensure_app_cache_dir('netharn/tests/deploy'), + >>> 'name' : 'deploy_demo', + >>> 'xpu' : nh.XPU.coerce('cpu'), + >>> 'datasets' : { + >>> 'train': nh.data.ToyData2d(size=3, border=1, n=256, rng=0), + >>> 'test': nh.data.ToyData2d(size=3, border=1, n=128, rng=1), + >>> }, + >>> 'loaders' : {'batch_size': 64}, + >>> 'model' : (nh.models.ToyNet2d, {}), + >>> 'optimizer' : (nh.optimizers.SGD, {'lr': 0.0001}), + >>> 'criterion' : (nh.criterions.CrossEntropyLoss, {}), + >>> 'initializer' : (nh.initializers.KaimingNormal, {}), + >>> 'scheduler' : (nh.schedulers.ListedLR, { + >>> 'points': {0: .01, 3: 0.1}, + >>> 'interpolate': True, + >>> }), + >>> 'monitor' : (nh.Monitor, {'max_epoch': 3,}), + >>> }) + >>> harn = nh.FitHarn(hyper) + >>> harn.preferences['use_tensorboard'] = False + >>> harn.preferences['timeout'] = 1 + >>> harn.intervals['test'] = 1 + >>> harn.initialize(reset='delete') + >>> harn.run() + --- STEP 1: TRAIN A MODEL --- + RESET HARNESS BY DELETING EVERYTHING IN TRAINING DIR + Symlink: .../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww -> .../.cache/netharn/tests/deploy/_mru + ...... 
+ Symlink: .../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww -> .../.cache/netharn/tests/deploy/fit/nice/deploy_demo + ...... + INFO: Model has 824 parameters + INFO: Mounting ToyNet2d model on CPU + INFO: Exported model topology to .../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/ToyNet2d_2a3f49.py + INFO: Initializing model weights with: + INFO: * harn.train_dpath = '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww' + INFO: * harn.name_dpath = '.../.cache/netharn/tests/deploy/fit/name/deploy_demo' + INFO: Snapshots will save to harn.snapshot_dpath = '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/torch_snapshots' + INFO: ARGV: + .../.local/conda/envs/py36/bin/python .../.local/conda/envs/py36/bin/ipython + INFO: === begin training 0 / 3 : deploy_demo === + epoch lr:0.01 │ vloss is unevaluated 0/3... rate=0 Hz, eta=?, total=0:00:00, wall=19:32 EST + train loss:0.717 │ 100.00% of 64x8... rate=2093.02 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST + test loss:0.674 │ 100.00% of 64x4... rate=14103.48 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST + Populating the interactive namespace from numpy and matplotlib + INFO: === finish epoch 0 / 3 : deploy_demo === + epoch lr:0.04 │ vloss is unevaluated 1/3... rate=0.87 Hz, eta=0:00:02, total=0:00:01, wall=19:32 EST + train loss:0.712 │ 100.00% of 64x8... rate=2771.29 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST + test loss:0.663 │ 100.00% of 64x4... rate=15867.59 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST + INFO: === finish epoch 1 / 3 : deploy_demo === + epoch lr:0.07 │ vloss is unevaluated 2/3... rate=1.04 Hz, eta=0:00:00, total=0:00:01, wall=19:32 EST + train loss:0.686 │ 100.00% of 64x8... rate=2743.56 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST + test loss:0.636 │ 100.00% of 64x4... rate=14332.63 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST + INFO: === finish epoch 2 / 3 : deploy_demo === + epoch lr:0.1 │ vloss is unevaluated 3/3... rate=1.11 Hz, eta=0:00:00, total=0:00:02, wall=19:32 EST + INFO: Maximum harn.epoch reached, terminating ... + INFO: + INFO: training completed + INFO: harn.train_dpath = '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww' + INFO: harn.name_dpath = '.../.cache/netharn/tests/deploy/fit/name/deploy_demo' + INFO: view tensorboard results for this run via: + tensorboard --logdir ~/.cache/netharn/tests/deploy/fit/name + [DEPLOYER] Deployed zipfpath=.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_TXZBYL.zip + INFO: wrote single-file deployment to: '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_TXZBYL.zip' + INFO: exiting fit harness. + Out[1]: '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_TXZBYL.zip' + >>> # + >>> ########################################## + >>> print('--- STEP 2: DEPLOY THE MODEL ---') + >>> # First we export the model topology to a standalone file + >>> # (Note: this step is done automatically in `harn.run`, but we do + >>> # it again here for demo purposes) + >>> import torch_liberator + >>> topo_fpath = torch_liberator.export_model_code(harn.train_dpath, harn.hyper.model_cls, harn.hyper.model_params) + >>> # Now create an instance of deployed model that points to the + >>> # Training dpath. 
(Note the directory structure set up by netharn is + >>> # itself a deployment, it just has multiple files) + >>> deployer = torch_liberator.DeployedModel(harn.train_dpath) + >>> # Use the DeployedModel to package the important info in train_dpath + >>> # into a standalone zipfile. + >>> zip_fpath = deployer.package() + >>> print('We exported the topology to: {!r}'.format(topo_fpath)) + >>> print('We exported the topology+weights to: {!r}'.format(zip_fpath)) + --- STEP 2: DEPLOY THE MODEL --- + We exported the topology to: '...tests/deploy/fit/runs/deploy_demo/onnxqaww/ToyNet2d_2a3f49.py' + We exported the topology+weights to: '...tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_HVWCGI.zip' + >>> # + >>> ################################################# + >>> print('--- STEP 3: LOAD THE DEPLOYED MODEL ---') + >>> # Now we can move the zipfile anywhere we want, and we should + >>> # still be able to load it (depending on how coupled the model is). + >>> # Create an instance of DeployedModel that points to the zipfile + >>> # (Note: DeployedModel is used to both package and load models) + >>> loader = torch_liberator.DeployedModel(zip_fpath) + >>> model = loader.load_model() + >>> # This model is now loaded with the correct weights. + >>> # You can use it as normal. + >>> model.eval() + >>> images = harn._demo_batch(0)[0][0:1] + >>> outputs = model(images) + >>> print('outputs = {!r}'.format(outputs)) + >>> # Note that the loaded model is independent of harn.model + >>> print('model.__module__ = {!r}'.format(model.__module__)) + >>> print('harn.model.module.__module__ = {!r}'.format(harn.model.module.__module__)) + --- STEP 3: LOAD THE DEPLOYED MODEL --- + outputs = tensor([[0.4105, 0.5895]], grad_fn=) + model.__module__ = 'deploy_ToyNet2d_onnxqaww_002_HVWCGI/ToyNet2d_2a3f49' + harn.model.module.__module__ = 'netharn.models.toynet' """ from netharn.export import deployer from netharn.export import exporter from netharn.export.deployer import (DeployedModel,) from netharn.export.exporter import (export_model_code,) +import warnings +warnings.warn('netharn.export is deprecated, use torch_liberator instead', DeprecationWarning) __all__ = ['DeployedModel', 'deployer', 'export_model_code', 'exporter'] diff --git a/netharn/export/closer.py b/netharn/export/closer.py index 3d3910ad4daec5d9275a659f96a25c69ee03804e..5875d1c12846946cbd97d864635bc448e9db6eb5 100644 --- a/netharn/export/closer.py +++ b/netharn/export/closer.py @@ -1,954 +1,4 @@ # -*- coding: utf-8 -*- -""" -Extracts relevant parts of the source code - -NOTE: - IF THE SOURCE CODE CHANGES WHILE THE RUN IS EXECUTING THEN THIS MAY NOT - WORK CORRECTLY. - -# TODO: -# - [x] Maintain a parse tree instead of raw lines -# - [x] Keep a mapping from "definition names" to the top-level nodes -# in the parse tree that define them. 
-# - [X] For each extracted node in the parse tree keep track of -# - [X] where it came from -# - [ ] what modifications were made to it -# - [ ] Handle expanding imports nested within functions -# - [ ] Maintain docstring formatting after using the node transformer -""" -from __future__ import absolute_import, division, print_function, unicode_literals -from os.path import isdir -from os.path import join -from os.path import basename -from collections import OrderedDict +from liberator.closer import * # NOQA import warnings -import ast -import astunparse -import inspect -import six -import ubelt as ub -from six.moves import cStringIO -from os.path import abspath -from os.path import sys - - -# There is a bug where closing netharn cant find "HiddenFields" -HACK_FIX_CANNOT_FIND_HIDDEN = 0 - -DEBUG = 0 - - -class Unparser(astunparse.Unparser): - """ - wraps astunparse to fix 2/3 compatibility minor issues - - Notes: - x = np.random.rand(3, 3) - # In python3 this works, but it fails in python2 - x[(..., 2)] - # However, this works in both - x[(Ellipsis, 2)] - # Interestingly, this also works, but is not how astunparse generates code - x[..., 2] - """ - def _Ellipsis(self, t): - # be compatible with python2 if possible - self.write("Ellipsis") - - -def unparse(tree): - """ wraps astunparse to fix 2/3 compatibility minor issues """ - v = cStringIO() - Unparser(tree, file=v) - return v.getvalue() - - -def source_closure(obj, expand_names=[]): - """ - Pulls the minimum amount of code needed to define `obj`. Uses a - combination of dynamic and static introspection. - - Args: - obj (type): the class whose definition will be exported. - - expand_names (List[str]): - EXPERIMENTAL. List of modules that should be expanded into raw - source code. - - Returns: - str: closed_sourcecode: text defining a new python module. 
- - CommandLine: - xdoctest -m netharn.export.closer source_closure - - Example: - >>> import torchvision - >>> from torchvision import models - >>> got = {} - - >>> model_class = models.AlexNet - >>> text = source_closure(model_class) - >>> assert not undefined_names(text) - >>> got['alexnet'] = ub.hash_data(text) - - >>> model_class = models.DenseNet - >>> text = source_closure(model_class) - >>> assert not undefined_names(text) - >>> got['densenet'] = ub.hash_data(text) - - >>> model_class = models.resnet50 - >>> text = source_closure(model_class) - >>> assert not undefined_names(text) - >>> got['resnet50'] = ub.hash_data(text) - - >>> model_class = models.Inception3 - >>> text = source_closure(model_class) - >>> assert not undefined_names(text) - >>> got['inception3'] = ub.hash_data(text) - - >>> # The hashes will depend on torchvision itself - >>> if torchvision.__version__ == '0.2.1': - >>> # Note: the hashes may change if the exporter changes formats - >>> want = { - >>> 'alexnet': '4b2ab9c8e27b34602bdff99cbc', - >>> 'densenet': 'fef4788586d2b93587ec52dd9', - >>> 'resnet50': '343e6a73e754557fcce3fdb6', - >>> 'inception3': '2e43a58133d0817753383', - >>> } - >>> failed = [] - >>> for k in want: - >>> if not got[k].startswith(want[k]): - >>> item = (k, got[k], want[k]) - >>> print('failed item = {!r}'.format(item)) - >>> failed.append(item) - >>> assert not failed, str(failed) - >>> else: - >>> warnings.warn('Unsupported version of torchvision') - - Example: - >>> # Test a heavier duty class - >>> from netharn.export.closer import * - >>> import netharn as nh - >>> obj = nh.layers.ConvNormNd - >>> expand_names = ['netharn'] - >>> text = source_closure(obj, expand_names) - >>> print(text) - - Ignore: - import netharn as nh - obj = nh.models.yolo2.yolo2.Yolo2 - expand_names = ['netharn'] - expand_names = [] - - print(chr(10).join(closer.logs)) - """ - closer = Closer() - - # First try to add statically (which tends to be slightly nicer) - try: - try: - name = obj.__name__ - modpath = sys.modules[obj.__module__].__file__ - except Exception: - # Otherwise add dynamically - closer.add_dynamic(obj) - else: - closer.add_static(name, modpath) - if expand_names: - closer.expand(expand_names) - closed_sourcecode = closer.current_sourcecode() - except Exception: - print('ERROR IN CLOSING') - print('[[[ START CLOSE LOGS ]]]') - print('closer.logs =\n{}'.format('\n'.join(closer.logs))) - print('[[[ END CLOSE LOGS ]]]') - raise - return closed_sourcecode - - -class Closer(ub.NiceRepr): - """ - Maintains the current state of the source code - - There are 3 major steps: - (a) extract the code to that defines a function or class from a module, - (b) go back to the module and extract extra code required to define any - names that were undefined in the extracted code, and - (c) replace import statements to specified "expand" modules with the actual code - used to define the variables accessed via the imports. - - This results in a standalone file that has absolutely no dependency on the - original module or the specified "expand" modules (the expand module is - usually the module that is doing the training for a network. This means - that you can deploy a model independant of the training framework). - - Note: - This is not designed to work for cases where the code depends on logic - executed in a global scope (e.g. dynamically registering properties) . - I think its actually impossible to statically account for this case in - general. 
- - Ignore: - >>> from netharn.export.closer import * - >>> import netharn as nh - >>> import fastai.vision - >>> obj = fastai.vision.models.WideResNet - >>> expand_names = ['fastai'] - >>> closer = Closer() - >>> closer.add_dynamic(obj) - >>> closer.expand(expand_names) - >>> #print(ub.repr2(closer.body_defs, si=1)) - >>> print(closer.current_sourcecode()) - - Ignore: - >>> from netharn.export.closer import * - >>> import netharn as nh - >>> from netharn.models.yolo2 import yolo2 - >>> obj = yolo2.Yolo2 - >>> expand_names = ['netharn'] - >>> closer = Closer() - >>> closer.add_static(obj.__name__, sys.modules[obj.__module__].__file__) - >>> closer.expand(expand_names) - >>> #print(ub.repr2(closer.body_defs, si=1)) - >>> print(closer.current_sourcecode()) - """ - def __init__(closer, tag='root'): - closer.header_defs = ub.odict() - closer.body_defs = ub.odict() - closer.visitors = {} - closer.tag = tag - - closer.logs = [] - closer._log_indent = '' - - def debug(closer, msg): - closer.logs.append(closer._log_indent + msg) - - def __nice__(self): - return self.tag - - def _add_definition(closer, d): - closer.debug('_add_definition = {!r}'.format(d)) - import copy - d = copy.deepcopy(d) - # print('ADD DEFINITION d = {!r}'.format(d)) - if 'Import' in d.type: - if d.absname in closer.header_defs: - del closer.header_defs[d.absname] - closer.header_defs[d.absname] = d - else: - if d.absname in closer.body_defs: - del closer.body_defs[d.absname] - closer.body_defs[d.absname] = d - - def current_sourcecode(self): - header_lines = [d.code for d in self.header_defs.values()] - body_lines = [d.code for d in self.body_defs.values()][::-1] - current_sourcecode = '\n'.join(header_lines) - current_sourcecode += '\n\n\n' - current_sourcecode += '\n\n\n'.join(body_lines) - return current_sourcecode - - def add_dynamic(closer, obj): - """ - Add the source to define a live python object - """ - closer.debug('closer.add_dynamic(obj={!r})'.format(obj)) - modname = obj.__module__ - module = sys.modules[modname] - - name = obj.__name__ - - modpath = module.__file__ - if modpath not in closer.visitors: - visitor = ImportVisitor.parse(module=module, modpath=modpath) - closer.visitors[modpath] = visitor - visitor = closer.visitors[modpath] - - d = visitor.extract_definition(name) - closer._add_definition(d) - closer.close(visitor) - - def add_static(closer, name, modpath): - # print('ADD_STATIC name = {} from {}'.format(name, modpath)) - closer.debug('closer.add_static(name={!r}, modpath={!r})'.format(name, modpath)) - if modpath not in closer.visitors: - visitor = ImportVisitor.parse(modpath=modpath) - closer.visitors[modpath] = visitor - visitor = closer.visitors[modpath] - - d = visitor.extract_definition(name) - closer._add_definition(d) - - closer.close(visitor) - - def close(closer, visitor): - """ - Populate all undefined names using the context from a module - """ - # Parse the parent module to find only the relevant global varaibles and - # include those in the extracted source code. 
- closer.debug('closing') - current_sourcecode = closer.current_sourcecode() - - # Loop until all undefined names are defined - names = True - while names: - # Determine if there are any variables needed from the parent scope - current_sourcecode = closer.current_sourcecode() - # Make sure we process names in the same order for hashability - prev_names = names - names = sorted(undefined_names(current_sourcecode)) - closer.debug(' * undefined_names = {}'.format(names)) - if names == prev_names: - print('visitor.definitions = {}'.format(ub.repr2(visitor.definitions, si=1))) - if DEBUG: - warnings.warn('We were unable do do anything about undefined names') - return - else: - current_sourcecode = closer.current_sourcecode() - print('--- ---') - print('Unable to define names') - print(' * names = {!r}'.format(names)) - print('<<< CURRENT_SOURCE >>>\n{}\n<<<>>>'.format(ub.highlight_code(current_sourcecode))) - print('--- ---') - raise AssertionError('unable to define names: {}'.format(names)) - for name in names: - try: - try: - closer.debug(' * try visitor.extract_definition({})'.format(names)) - d = visitor.extract_definition(name) - except KeyError as ex: - closer.debug(' * encountered issue: {!r}'.format(ex)) - # There is a corner case where we have the definition, - # we just need to move it to the top. - flag = False - for d_ in closer.body_defs.values(): - if name == d_.name: - closer.debug(' * corner case: move definition to top') - closer._add_definition(d_) - flag = True - break - if not flag: - raise - else: - closer.debug(' * add extracted def {}'.format(name)) - closer._add_definition(d) - # type_, text = visitor.extract_definition(name) - except Exception as ex: - closer.debug(' * unable to extracted def {} due to {!r}'.format(name, ex)) - current_sourcecode = closer.current_sourcecode() - print('--- ---') - print('Error computing source code extract_definition') - print(' * failed to close name = {!r}'.format(name)) - # print('<<< CURRENT_SOURCE >>>\n{}\n<<<>>>'.format(ub.highlight_code(current_sourcecode))) - print('--- ---') - if not HACK_FIX_CANNOT_FIND_HIDDEN: - raise - - def expand(closer, expand_names): - """ - Experimental feature. Remove all references to specific modules by - directly copying in the referenced source code. If the code is - referenced from a module, then the references will need to change as - well. - - TODO: - - [ ] Add special unique (mangled) suffixes to all expanded names - to avoid name conflicts. - - Args: - expand_name (List[str]): list of module names. For each module - we expand any reference to that module in the closed source - code by directly copying the referenced code into that file. - This doesn't work in all cases, but it usually does. - Reasons why this wouldn't work include trying to expand - import from C-extension modules and expanding modules with - complicated global-level logic. - - Ignore: - >>> # Test a heavier duty class - >>> from netharn.export.closer import * - >>> import netharn as nh - >>> obj = nh.device.MountedModel - >>> #obj = nh.layers.ConvNormNd - >>> obj = nh.data.CocoDataset - >>> #expand_names = ['ubelt', 'progiter'] - >>> closer = Closer() - >>> closer.add_dynamic(obj) - >>> closer.expand(expand_names) - >>> #print('header_defs = ' + ub.repr2(closer.header_defs, si=1)) - >>> #print('body_defs = ' + ub.repr2(closer.body_defs, si=1)) - >>> print('SOURCE:') - >>> text = closer.current_sourcecode() - >>> print(text) - """ - closer.debug("!!! 
EXPANDING") - # Expand references to internal modules - flag = True - while flag: - - # Associate all top-level modules with any possible expand_name - # that might trigger them to be expanded. Note this does not - # account for nested imports. - expandable_definitions = ub.ddict(list) - for d in closer.header_defs.values(): - parts = d.native_modname.split('.') - for i in range(1, len(parts) + 1): - root = '.'.join(parts[:i]) - expandable_definitions[root].append(d) - - closer.debug('expandable_definitions = {!r}'.format( - list(expandable_definitions.keys()))) - - flag = False - # current_sourcecode = closer.current_sourcecode() - # closed_visitor = ImportVisitor.parse(source=current_sourcecode) - for root in expand_names: - needs_expansion = expandable_definitions.get(root, []) - - closer.debug('root = {!r}'.format(root)) - closer.debug('needs_expansion = {!r}'.format(needs_expansion)) - for d in needs_expansion: - if d._expanded: - continue - flag = True - # if d.absname == d.native_modname: - if ub.modname_to_modpath(d.absname): - closer.debug('TODO: NEED TO CLOSE module = {}'.format(d)) - # import warnings - # warnings.warn('Closing module {} may not be implemented'.format(d)) - # definition is a module, need to expand its attributes - closer.expand_module_attributes(d) - d._expanded = True - else: - closer.debug('TODO: NEED TO CLOSE attribute varname = {}'.format(d)) - import warnings - # warnings.warn('Closing attribute {} may not be implemented'.format(d)) - # definition is a non-module, directly copy in its code - # We can directly replace this import statement by - # copy-pasting the relevant code from the other module - # (ASSUMING THERE ARE NO NAME CONFLICTS) - - assert d.type == 'ImportFrom' - - try: - native_modpath = ub.modname_to_modpath(d.native_modname) - if native_modpath is None: - raise Exception('Cannot find the module path for modname={!r}. 
' - 'Are you missing an __init__.py?'.format(d.native_modname)) - sub_closer = Closer(closer.tag + '.sub') - sub_closer.add_static(d.name, native_modpath) - # sub_visitor = sub_closer.visitors[d.native_modname] - sub_closer.expand(expand_names) - # sub_closer.close(sub_visitor) - except NotAPythonFile as ex: - warnings.warn('CANNOT EXPAND d = {!r}, REASON: {}'.format(d, repr(ex))) - d._expanded = True - raise - continue - except Exception as ex: - warnings.warn('CANNOT EXPAND d = {!r}, REASON: {}'.format(d, repr(ex))) - d._expanded = True - raise - continue - # raise - - # Hack: remove the imported definition and add the explicit definition - # TODO: FIXME: more robust modification and replacement - d._code = '# ' + d.code - d._expanded = True - - for d_ in sub_closer.header_defs.values(): - closer._add_definition(d_) - for d_ in sub_closer.body_defs.values(): - closer._add_definition(d_) - - # print('sub_visitor = {!r}'.format(sub_visitor)) - # closer.close(sub_visitor) - closer.debug('CLOSED attribute d = {}'.format(d)) - - def expand_module_attributes(closer, d): - """ - Args: - d (Definition): the definition to expand - """ - # current_sourcecode = closer.current_sourcecode() - # closed_visitor = ImportVisitor.parse(source=current_sourcecode) - assert 'Import' in d.type - varname = d.name - varmodpath = ub.modname_to_modpath(d.absname) - modname = d.absname - - def _exhaust(varname, modname, modpath): - closer.debug('REWRITE ACCESSOR varname={!r}, modname={}, modpath={}'.format(varname, modname, modpath)) - - # Modify the current node definitions and recompute code - # TODO: make more robust - rewriter = RewriteModuleAccess(varname) - for d_ in closer.body_defs.values(): - rewriter.visit(d_.node) - d_._code = unparse(d_.node) - - closer.debug('rewriter.accessed_attrs = {!r}'.format(rewriter.accessed_attrs)) - - # For each modified attribute, copy in the appropriate source. - for subname in rewriter.accessed_attrs: - submodname = modname + '.' 
+ subname - submodpath = ub.modname_to_modpath(submodname) - if submodpath is not None: - # if the accessor is to another module, exhaust until - # we reach a non-module - closer.debug('EXAUSTING: {}, {}, {}'.format(subname, submodname, submodpath)) - _exhaust(subname, submodname, submodpath) - else: - # Otherwise we can directly add the referenced attribute - closer.debug('FINALIZE: {} from {}'.format(subname, modpath)) - closer.add_static(subname, modpath) - - _exhaust(varname, modname, varmodpath) - d._code = '# ' + d.code - - -def _parse_static_node_value(node): - """ - Extract a constant value from a node if possible - """ - if isinstance(node, ast.Num): - value = node.n - elif isinstance(node, ast.Str): - value = node.s - elif isinstance(node, ast.List): - value = list(map(_parse_static_node_value, node.elts)) - elif isinstance(node, ast.Tuple): - value = tuple(map(_parse_static_node_value, node.elts)) - elif isinstance(node, (ast.Dict)): - keys = map(_parse_static_node_value, node.keys) - values = map(_parse_static_node_value, node.values) - value = OrderedDict(zip(keys, values)) - # value = dict(zip(keys, values)) - elif six.PY3 and isinstance(node, (ast.NameConstant)): - value = node.value - elif (six.PY2 and isinstance(node, ast.Name) and - node.id in ['None', 'True', 'False']): - # disregard pathological python2 corner cases - value = {'None': None, 'True': True, 'False': False}[node.id] - else: - msg = ('Cannot parse a static value from non-static node ' - 'of type: {!r}'.format(type(node))) - # print('node.__dict__ = {!r}'.format(node.__dict__)) - # print('msg = {!r}'.format(msg)) - raise TypeError(msg) - return value - - -def undefined_names(sourcecode): - """ - Parses source code for undefined names - - Example: - >>> print(ub.repr2(undefined_names('x = y'), nl=0)) - {'y'} - """ - import pyflakes.api - import pyflakes.reporter - - class CaptureReporter(pyflakes.reporter.Reporter): - def __init__(reporter, warningStream, errorStream): - reporter.syntax_errors = [] - reporter.messages = [] - reporter.unexpected = [] - - def unexpectedError(reporter, filename, msg): - reporter.unexpected.append(msg) - - def syntaxError(reporter, filename, msg, lineno, offset, text): - reporter.syntax_errors.append(msg) - - def flake(reporter, message): - reporter.messages.append(message) - - names = set() - - reporter = CaptureReporter(None, None) - pyflakes.api.check(sourcecode, '_.py', reporter) - for msg in reporter.messages: - if msg.__class__.__name__.endswith('UndefinedName'): - assert len(msg.message_args) == 1 - names.add(msg.message_args[0]) - return names - - -class RewriteModuleAccess(ast.NodeTransformer): - """ - Refactors attribute accesses into top-level references. - In other words, instances of . change to . - - Any attributes that were modified are stored in `accessed_attrs`. - - Example: - >>> from netharn.export.closer import * - >>> source = ub.codeblock( - ... ''' - ... foo.bar = 3 - ... foo.baz.bar = 3 - ... biz.foo.baz.bar = 3 - ... 
''') - >>> pt = ast.parse(source) - >>> visitor = RewriteModuleAccess('foo') - >>> orig = unparse(pt) - >>> print(orig) - foo.bar = 3 - foo.baz.bar = 3 - biz.foo.baz.bar = 3 - >>> visitor.visit(pt) - >>> modified = unparse(pt) - >>> print(modified) - bar = 3 - baz.bar = 3 - biz.foo.baz.bar = 3 - >>> visitor.accessed_attrs - ['bar', 'baz'] - """ - def __init__(self, modname): - self.modname = modname - self.level = 0 - self.accessed_attrs = [] - - def visit_Import(self, node): - # if self.level == 0: - # return None - return node - - def visit_ImportFrom(self, node): - # if self.level == 0: - # return None - return node - - def visit_FunctionDef(self, node): - self.level += 1 - self.generic_visit(node) - self.level -= 1 - return node - - def visit_ClassDef(self, node): - self.level += 1 - self.generic_visit(node) - self.level -= 1 - return node - - def visit_Attribute(self, node): - # print('VISIT ATTR: node = {!r}'.format(node.__dict__)) - self.generic_visit(node) - if isinstance(node.value, ast.Name): - if node.value.id == self.modname: - self.accessed_attrs.append(node.attr) - new_node = ast.Name(node.attr, node.ctx) - old_node = node - return ast.copy_location(new_node, old_node) - return node - - -class Definition(ub.NiceRepr): - def __init__(self, name, node, type=None, code=None, absname=None, - modpath=None, modname=None, native_modname=None): - self.name = name - self.node = node - self.type = type - self._code = code - self.absname = absname - self.modpath = modpath - self.modname = modname - self.native_modname = native_modname - self._expanded = False - - @property - def code(self): - if self._code is None: - try: - if self._expanded or self.type == 'Assign': - # always use astunparse if we have expanded - raise Exception - # Attempt to dynamically extract the source code because it - # keeps formatting better. - module = ub.import_module_from_name(self.modname) - obj = getattr(module, self.name) - self._code = inspect.getsource(obj).strip('\n') - except Exception: - # Fallback on static sourcecode extraction - # (NOTE: it should be possible to keep formatting with a bit of - # work) - self._code = unparse(self.node).strip('\n') - return self._code - - def __nice__(self): - parts = [] - parts.append('name={}'.format(self.name)) - parts.append('type={}'.format(self.type)) - if self.absname is not None: - parts.append('absname={}'.format(self.absname)) - if self.native_modname is not None: - parts.append('native_modname={}'.format(self.native_modname)) - return ', '.join(parts) - - -class NotAPythonFile(ValueError): - pass - - -class ImportVisitor(ast.NodeVisitor, ub.NiceRepr): - """ - Used to search for dependencies in the original module - - References: - https://greentreesnakes.readthedocs.io/en/latest/nodes.html - - Example: - >>> from netharn.export.closer import * - >>> from netharn.export import closer - >>> modpath = closer.__file__ - >>> sourcecode = ub.codeblock( - ... ''' - ... from ubelt.util_const import * - ... import a - ... import b - ... import c.d - ... import e.f as g - ... from . import h - ... from .i import j - ... from . import k, l, m - ... from n import o, p, q - ... r = 3 - ... 
''') - >>> visitor = ImportVisitor.parse(source=sourcecode, modpath=modpath) - >>> print(ub.repr2(visitor.definitions, si=1)) - """ - - def __init__(visitor, modpath=None, modname=None, module=None, pt=None): - super(ImportVisitor, visitor).__init__() - visitor.pt = pt - visitor.modpath = modpath - visitor.modname = modname - visitor.module = module - - visitor.definitions = {} - visitor.top_level = True - - def __nice__(self): - if self.modname is not None: - return self.modname - else: - return "" - - @classmethod - def parse(ImportVisitor, source=None, modpath=None, modname=None, - module=None): - if module is not None: - if source is None: - source = inspect.getsource(module) - if modpath is None: - modname = module.__file__ - if modname is None: - modname = module.__name__ - - if modpath is not None: - if modpath.endswith('.pyc'): - modpath = modpath.replace('.pyc', '.py') # python 2 hack - - if isdir(modpath): - modpath = join(modpath, '__init__.py') - if modname is None: - modname = ub.modpath_to_modname(modpath) - - if modpath is not None: - if source is None: - if not modpath.endswith(('.py', '>')): - raise NotAPythonFile('can only parse python files, not {}'.format(modpath)) - source = open(modpath, 'r').read() - - if source is None: - raise ValueError('unable to derive source code') - - source = ub.ensure_unicode(source) - if six.PY2: - try: - pt = ast.parse(source) - except SyntaxError as ex: - if 'encoding declaration in Unicode string' in ex.args[0]: - pt = ast.parse(source.encode()) - else: - raise - else: - pt = ast.parse(source) - visitor = ImportVisitor(modpath, modname, module, pt=pt) - visitor.visit(pt) - return visitor - - def extract_definition(visitor, name): - """ - Given the name of a variable / class / function / moodule, extract the - relevant lines of source code that define that structure from the - visited module. - """ - return visitor.definitions[name] - - def visit_Import(visitor, node): - for d in visitor._import_definitions(node): - visitor.definitions[d.name] = d - visitor.generic_visit(node) - - def visit_ImportFrom(visitor, node): - for d in visitor._import_from_definition(node): - visitor.definitions[d.name] = d - visitor.generic_visit(node) - - def visit_Assign(visitor, node): - for target in node.targets: - key = getattr(target, 'id', None) - if key is not None: - try: - static_val = _parse_static_node_value(node.value) - code = '{} = {}'.format(key, ub.repr2(static_val)) - except TypeError: - #code = unparse(node).strip('\n') - code = None - - if DEBUG: - if key in visitor.definitions: - # OVERLOADED - print('OVERLOADED key = {!r}'.format(key)) - - visitor.definitions[key] = Definition( - key, node, code=code, type='Assign', - modpath=visitor.modpath, - modname=visitor.modname, - absname=visitor.modname + '.' + key, - native_modname=visitor.modname, - ) - - def visit_FunctionDef(visitor, node): - visitor.definitions[node.name] = Definition( - node.name, node, type='FunctionDef', - modpath=visitor.modpath, - modname=visitor.modname, - absname=visitor.modname + '.' + node.name, - native_modname=visitor.modname, - ) - # Ignore any non-top-level imports - if not visitor.top_level: - visitor.generic_visit(node) - # ast.NodeVisitor.generic_visit(visitor, node) - - def visit_ClassDef(visitor, node): - visitor.definitions[node.name] = Definition( - node.name, node, type='ClassDef', - modpath=visitor.modpath, - modname=visitor.modname, - absname=visitor.modname + '.' 
+ node.name, - native_modname=visitor.modname, - ) - # Ignore any non-top-level imports - if not visitor.top_level: - visitor.generic_visit(node) - # ast.NodeVisitor.generic_visit(visitor, node) - - def _import_definitions(visitor, node): - for alias in node.names: - varname = alias.asname or alias.name - if alias.asname: - line = 'import {} as {}'.format(alias.name, alias.asname) - else: - line = 'import {}'.format(alias.name) - absname = alias.name - yield Definition(varname, node, code=line, - absname=absname, - native_modname=absname, - modpath=visitor.modpath, - modname=visitor.modname, - type='Import') - - def _import_from_definition(visitor, node): - """ - Ignore: - from netharn.export.closer import * - visitor = ImportVisitor.parse(module=module) - print('visitor.definitions = {}'.format(ub.repr2(visitor.definitions, sv=1))) - """ - if node.level: - # Handle relative imports - if visitor.modpath is not None: - try: - rel_modpath = ub.split_modpath(abspath(visitor.modpath))[1] - except ValueError: - warnings.warn('modpath={} does not exist'.format(visitor.modpath)) - rel_modpath = basename(abspath(visitor.modpath)) - modparts = rel_modpath.replace('\\', '/').split('/') - parts = modparts[:-node.level] - prefix = '.'.join(parts) - if node.module: - prefix = prefix + '.' - else: - warnings.warn('Unable to rectify absolute import') - prefix = '.' * node.level - else: - prefix = '' - - if node.module is not None: - abs_modname = prefix + node.module - else: - abs_modname = prefix - - for alias in node.names: - varname = alias.asname or alias.name - if alias.asname: - line = 'from {} import {} as {}'.format(abs_modname, alias.name, alias.asname) - else: - line = 'from {} import {}'.format(abs_modname, alias.name) - absname = abs_modname + '.' + alias.name - if varname == '*': - # HACK - abs_modpath = ub.modname_to_modpath(abs_modname) - for d in ImportVisitor.parse(modpath=abs_modpath).definitions.values(): - if not d.name.startswith('_'): - yield d - else: - yield Definition(varname, node, code=line, absname=absname, - modpath=visitor.modpath, - modname=visitor.modname, - native_modname=abs_modname, - type='ImportFrom') - - -def _closefile(fpath, modnames): - """ - An api to remove dependencies from code by "closing" them. 
- - CommandLine: - xdoctest -m ~/code/netharn/netharn/export/closer.py _closefile - xdoctest -m netharn.export.closer _closefile --fpath=~/code/boltons/tests/test_cmdutils.py --modnames=ubelt, - - Example: - >>> # SCRIPT - >>> # ENTRYPOINT - >>> import scriptconfig as scfg - >>> config = scfg.quick_cli({ - >>> 'fpath': scfg.Path(None), - >>> 'modnames': scfg.Value([]), - >>> }) - >>> fpath = config['fpath'] = ub.expandpath('~/code/boltons/tests/test_cmdutils.py') - >>> modnames = config['modnames'] = ['ubelt'] - >>> _closefile(**config) - """ - from xdoctest import static_analysis as static - modpath = fpath - expand_names = modnames - source = open(fpath, 'r').read() - calldefs = static.parse_calldefs(source, fpath) - calldefs.pop('__doc__', None) - - closer = Closer() - for key in calldefs.keys(): - closer.add_static(key, modpath) - closer.expand(expand_names) - #print(ub.repr2(closer.body_defs, si=1)) - print(closer.current_sourcecode()) +warnings.warn('netharn.export.closer is deprecated, use liberator.closer intead', DeprecationWarning) diff --git a/netharn/export/deployer.py b/netharn/export/deployer.py index 4d37988d687381d62487ab4cb1d32731a19847bb..affb792f7fcfbc7d0f8af2a1554cae2823e7e280 100644 --- a/netharn/export/deployer.py +++ b/netharn/export/deployer.py @@ -1,733 +1,4 @@ # -*- coding: utf-8 -*- -""" -Deployment component of the Pytorch exporter. - -This file contains DeployedModel, which consists of logic to take the -model topology definition along with the "best" snapshot in a training -directory and package it up into a standalone zipfile. The DeployedModel can -also be used to reload model from this zipfile. Thus this zipfile can be passed -around as a pytorch model topology+pretrained weights transfer format. - -The following docstring illustrates how this module may be used. - -CommandLine: - # Runs the following example - xdoctest -m netharn.export.deployer __doc__:0 - - # Runs all the doctests - xdoctest -m netharn.export.deployer all - -Example: - >>> # xdoc: +IGNORE_WANT - >>> # This example will train a small model and then deploy it. - >>> import netharn as nh - >>> # - >>> ################################################# - >>> print('--- STEP 1: TRAIN A MODEL ---') - >>> # This will train a toy model with toy data using netharn - >>> hyper = nh.HyperParams(**{ - >>> 'workdir' : ub.ensure_app_cache_dir('netharn/tests/deploy'), - >>> 'name' : 'deploy_demo', - >>> 'xpu' : nh.XPU.coerce('cpu'), - >>> 'datasets' : { - >>> 'train': nh.data.ToyData2d(size=3, border=1, n=256, rng=0), - >>> 'test': nh.data.ToyData2d(size=3, border=1, n=128, rng=1), - >>> }, - >>> 'loaders' : {'batch_size': 64}, - >>> 'model' : (nh.models.ToyNet2d, {}), - >>> 'optimizer' : (nh.optimizers.SGD, {'lr': 0.0001}), - >>> 'criterion' : (nh.criterions.CrossEntropyLoss, {}), - >>> 'initializer' : (nh.initializers.KaimingNormal, {}), - >>> 'scheduler' : (nh.schedulers.ListedLR, { - >>> 'points': {0: .01, 3: 0.1}, - >>> 'interpolate': True, - >>> }), - >>> 'monitor' : (nh.Monitor, {'max_epoch': 3,}), - >>> }) - >>> harn = nh.FitHarn(hyper) - >>> harn.preferences['use_tensorboard'] = False - >>> harn.preferences['timeout'] = 1 - >>> harn.intervals['test'] = 1 - >>> harn.initialize(reset='delete') - >>> harn.run() - --- STEP 1: TRAIN A MODEL --- - RESET HARNESS BY DELETING EVERYTHING IN TRAINING DIR - Symlink: .../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww -> .../.cache/netharn/tests/deploy/_mru - ...... 
- Symlink: .../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww -> .../.cache/netharn/tests/deploy/fit/nice/deploy_demo - ...... - INFO: Model has 824 parameters - INFO: Mounting ToyNet2d model on CPU - INFO: Exported model topology to .../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/ToyNet2d_2a3f49.py - INFO: Initializing model weights with: - INFO: * harn.train_dpath = '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww' - INFO: * harn.name_dpath = '.../.cache/netharn/tests/deploy/fit/name/deploy_demo' - INFO: Snapshots will save to harn.snapshot_dpath = '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/torch_snapshots' - INFO: ARGV: - .../.local/conda/envs/py36/bin/python .../.local/conda/envs/py36/bin/ipython - INFO: === begin training 0 / 3 : deploy_demo === - epoch lr:0.01 │ vloss is unevaluated 0/3... rate=0 Hz, eta=?, total=0:00:00, wall=19:32 EST - train loss:0.717 │ 100.00% of 64x8... rate=2093.02 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST - test loss:0.674 │ 100.00% of 64x4... rate=14103.48 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST - Populating the interactive namespace from numpy and matplotlib - INFO: === finish epoch 0 / 3 : deploy_demo === - epoch lr:0.04 │ vloss is unevaluated 1/3... rate=0.87 Hz, eta=0:00:02, total=0:00:01, wall=19:32 EST - train loss:0.712 │ 100.00% of 64x8... rate=2771.29 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST - test loss:0.663 │ 100.00% of 64x4... rate=15867.59 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST - INFO: === finish epoch 1 / 3 : deploy_demo === - epoch lr:0.07 │ vloss is unevaluated 2/3... rate=1.04 Hz, eta=0:00:00, total=0:00:01, wall=19:32 EST - train loss:0.686 │ 100.00% of 64x8... rate=2743.56 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST - test loss:0.636 │ 100.00% of 64x4... rate=14332.63 Hz, eta=0:00:00, total=0:00:00, wall=19:32 EST - INFO: === finish epoch 2 / 3 : deploy_demo === - epoch lr:0.1 │ vloss is unevaluated 3/3... rate=1.11 Hz, eta=0:00:00, total=0:00:02, wall=19:32 EST - INFO: Maximum harn.epoch reached, terminating ... - INFO: - INFO: training completed - INFO: harn.train_dpath = '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww' - INFO: harn.name_dpath = '.../.cache/netharn/tests/deploy/fit/name/deploy_demo' - INFO: view tensorboard results for this run via: - tensorboard --logdir ~/.cache/netharn/tests/deploy/fit/name - [DEPLOYER] Deployed zipfpath=.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_TXZBYL.zip - INFO: wrote single-file deployment to: '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_TXZBYL.zip' - INFO: exiting fit harness. - Out[1]: '.../.cache/netharn/tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_TXZBYL.zip' - >>> # - >>> ########################################## - >>> print('--- STEP 2: DEPLOY THE MODEL ---') - >>> # First we export the model topology to a standalone file - >>> # (Note: this step is done automatically in `harn.run`, but we do - >>> # it again here for demo purposes) - >>> from netharn.export import exporter - >>> topo_fpath = exporter.export_model_code(harn.train_dpath, harn.hyper.model_cls, harn.hyper.model_params) - >>> # Now create an instance of deployed model that points to the - >>> # Training dpath. 
(Note the directory structure setup by netharn is - >>> # itself a deployment, it just has multiple files) - >>> deployer = DeployedModel(harn.train_dpath) - >>> # Use the DeployedModel to package the imporant info in train_dpath - >>> # into a standalone zipfile. - >>> zip_fpath = deployer.package() - >>> print('We exported the topology to: {!r}'.format(topo_fpath)) - >>> print('We exported the topology+weights to: {!r}'.format(zip_fpath)) - --- STEP 2: DEPLOY THE MODEL --- - We exported the topology to: '...tests/deploy/fit/runs/deploy_demo/onnxqaww/ToyNet2d_2a3f49.py' - We exported the topology+weights to: '...tests/deploy/fit/runs/deploy_demo/onnxqaww/deploy_ToyNet2d_onnxqaww_002_HVWCGI.zip' - >>> # - >>> ################################################# - >>> print('--- STEP 3: LOAD THE DEPLOYED MODEL ---') - >>> # Now we can move the zipfile anywhere we want, and we should - >>> # still be able to load it (depending on how coupled the model is). - >>> # Create an instance of DeployedModel that points to the zipfile - >>> # (Note: DeployedModel is used to both package and load models) - >>> loader = DeployedModel(zip_fpath) - >>> model = loader.load_model() - >>> # This model is now loaded with the corret weights. - >>> # You can use it as normal. - >>> model.eval() - >>> images = harn._demo_batch(0)['input'][0:1] - >>> outputs = model(images) - >>> print('outputs = {!r}'.format(outputs)) - >>> # Not that the loaded model is independent of harn.model - >>> print('model.__module__ = {!r}'.format(model.__module__)) - >>> print('harn.model.module.__module__ = {!r}'.format(harn.model.module.__module__)) - --- STEP 3: LOAD THE DEPLOYED MODEL --- - outputs = tensor([[0.4105, 0.5895]], grad_fn=) - model.__module__ = 'deploy_ToyNet2d_onnxqaww_002_HVWCGI/ToyNet2d_2a3f49' - harn.model.module.__module__ = 'netharn.models.toynet' -""" -from __future__ import absolute_import, division, print_function, unicode_literals -import glob -import json -import six -import ubelt as ub -# import warnings -import zipfile -import os -from os.path import exists -from os.path import isdir -from os.path import join -from os.path import relpath - -__all__ = ['DeployedModel'] - -if six.PY2: - FileNotFoundError = OSError - - -def existing_snapshots(train_dpath): - # NOTE: Specific to netharn directory structure - import parse - snapshot_dpath = join(train_dpath, 'torch_snapshots/') - prev_states = sorted(glob.glob(join(snapshot_dpath, '_epoch_*.pt'))) - snapshots = {parse.parse('{}_epoch_{num:d}.pt', path).named['num']: path - for path in prev_states} - return snapshots - - -def find_best_snapshot(train_dpath): - """ - Returns snapshot written by monitor if available otherwise takes the last - one. - """ - # NOTE: Specific to netharn directory structure - # Netharn should populate best_snapshot.pt if there is a validation set. - # Other names are to support older codebases. - expected_names = [ - 'best_snapshot.pt', - 'best_snapshot2.pt', - 'final_snapshot.pt', - 'deploy_snapshot.pt', - ] - for snap_fname in expected_names: - snap_fpath = join(train_dpath, snap_fname) - if exists(snap_fpath): - break - - if not exists(snap_fpath): - snap_fpath = None - - if not snap_fpath: - epoch_to_fpath = existing_snapshots(train_dpath) - if epoch_to_fpath: - snap_fpath = epoch_to_fpath[max(epoch_to_fpath)] - return snap_fpath - - -def unpack_model_info(path): - """ - return paths to the most relevant files in a zip or path deployment. - - If path is not a zipfile, this function expects a netharn fit directory - structure. 
- - Args: - path (PathLike): either a zip deployment or train_dpath. - Preferably this is a zip deployment file or a path to an unzipped - deploy file. If this is a train_dpath, then it should at least - contain a model topology py file and snapshot pt file, otherwise - subsequent usage will likely fail. - """ - info = { - 'train_info_fpath': None, - 'snap_fpath': None, - 'model_fpath': None, - - # TODO: need to rename and allow a list of arbitrary files - 'glance': [], # a list of files in the glance directory - } - def populate(root, fpaths): - # TODO: make more robust - for fpath in fpaths: - # FIXME: make this more general and robust - if fpath.endswith('.json'): - info['train_info_fpath'] = join(root, fpath) - if fpath.endswith('.pt'): - info['snap_fpath'] = join(root, fpath) - if fpath.endswith('.py'): - new_fpath = join(root, fpath) - if info['model_fpath'] is not None: - try: - # Try to take the most recent path if possible. - # This will fail if the file is in a zipfile - # (because we should not package multiple models) - cur_time = os.stat(info['model_fpath']).st_mtime - new_time = os.stat(new_fpath).st_mtime - if new_time < cur_time: - continue # Keep the current path - except OSError: - raise Exception( - 'Multiple model paths! {} and {}'.format( - info['model_fpath'], fpath)) - info['model_fpath'] = new_fpath - # TODO: make including arbitrary files easier - if fpath.startswith(('glance/', 'glance\\')): - info['glance'].append(join(root, fpath)) - - if path.endswith('.zip'): - zipfpath = path - myzip = zipfile.ZipFile(zipfpath, 'r') - with zipfile.ZipFile(zipfpath, 'r') as myzip: - populate(zipfpath, (f.filename for f in myzip.filelist)) - - elif exists(path) and isdir(path): - # Populate core files - populate(path, os.listdir(path)) - # Populate extra glanceable files - populate(path, [ - relpath(p, path) for p in glob.glob(join(path, 'glance/*'))]) - # If there are no snapshots in the root directory, then - # use the latest snapshot from the torch_snapshots dir - if info['snap_fpath'] is None: - info['snap_fpath'] = find_best_snapshot(path) - - else: - raise ValueError('cannot unpack model ' + path) - return info - - -def _make_package_name2(info): - """ - Construct a unique and descriptive name for the deployment - """ - snap_fpath = info['snap_fpath'] - model_fpath = info['model_fpath'] - train_info_fpath = info['train_info_fpath'] - - if train_info_fpath and exists(train_info_fpath): - train_info = json.load(open(train_info_fpath, 'r')) - model_name = train_info['hyper']['model'][0].split('.')[-1] - train_hash = ub.hash_data(train_info['train_id'], hasher='sha512', - base='abc', types=True)[0:8] - else: - model_name = os.path.splitext(os.path.basename(model_fpath))[0] - train_hash = 'UNKNOWN-TRAINID' - print('WARNING: Train info metadata does not exist') - - try: - # netharn models contain epoch info in the weights file - import torch - state = torch.load(snap_fpath, - map_location=lambda storage, location: storage) - epoch = '{:03d}'.format(state['epoch']) - except Exception: - epoch = 'UNKNOWN-EPOCH' - - weights_hash = ub.hash_file(snap_fpath, base='abc', - hasher='sha512')[0:6].upper() - - deploy_name = 'deploy_{model}_{trainid}_{epoch}_{weights}'.format( - model=model_name, trainid=train_hash, epoch=epoch, - weights=weights_hash) - return deploy_name - - -def _package_deploy2(dpath, info, name=None): - """ - Combine the model, weights, and info files into a single deployable file - - Args: - dpath (PathLike): where to dump the deployment - info (Dict): containing 
model_fpath and snap_fpath and optionally - train_info_fpath and glance, which is a list of extra files. - name (str, default=None): the name of the zipfile to deploy to. - If not specified, one will be constructed. - - Ignore: - dpath = '/home/joncrall/.cache/netharn/tests/_package_custom' - path = '/home/joncrall/work/opir/fit/name/_Sim3-kw6-99-finetune_ML3D_BEST_2018-9-20_LR1e-4_f2_vel0.0_hn0.25_bs64_nr5.0' - info = unpack_model_info(path) - zipfpath = _package_deploy2(dpath, info) - - - """ - model_fpath = info['model_fpath'] - snap_fpath = info['snap_fpath'] - train_info_fpath = info.get('train_info_fpath', None) - - if not snap_fpath: - raise FileNotFoundError('No weights are associated with the model') - - if name is None: - deploy_name = _make_package_name2(info) - deploy_fname = deploy_name + '.zip' - else: - if not name.endswith('.zip'): - raise ValueError('The deployed package name must end in .zip') - deploy_name = os.path.splitext(name)[0] - deploy_fname = name - - def zwrite(myzip, fpath, fname=None): - if fname is None: - fname = relpath(fpath, dpath) - myzip.write(fpath, arcname=join(deploy_name, fname)) - - zipfpath = join(dpath, deploy_fname) - with zipfile.ZipFile(zipfpath, 'w') as myzip: - if train_info_fpath and exists(train_info_fpath): - zwrite(myzip, train_info_fpath, fname='train_info.json') - zwrite(myzip, snap_fpath, fname='deploy_snapshot.pt') - zwrite(myzip, model_fpath, fname=os.path.basename(model_fpath)) - # Add some quick glanceable info - for p in info.get('glance', []): - zwrite(myzip, p, fname=join('glance', os.path.basename(p))) - # for bestacc_fpath in glob.glob(join(train_dpath, 'best_epoch_*')): - # zwrite(myzip, bestacc_fpath) - # for p in glob.glob(join(train_dpath, 'glance/*')): - # zwrite(myzip, p) - print('[DEPLOYER] Deployed zipfpath={}'.format(zipfpath)) - return zipfpath - - -class DeployedModel(ub.NiceRepr): - """ - Can setup an initializer and model from a deployed zipfile or a train path - - CommandLine: - xdoctest -m netharn.export.deployer DeployedModel - - Example: - >>> # Test the train folder as the model deployment - >>> train_dpath = _demodata_trained_dpath() - >>> self = DeployedModel(train_dpath) - >>> model_ = self.model_definition() - >>> initializer_ = self.initializer_definition() - >>> model = model_[0](**model_[1]) - >>> assert initializer_[1].get('fpath', None) is not None, 'initializer isnt setup correctly' - >>> initializer = initializer_[0](**initializer_[1]) - >>> initializer(model) - ... - >>> print('model.__module__ = {!r}'.format(model.__module__)) - - # >>> if six.PY3: - # ... assert model.__module__ == 'ToyNet2d_2a3f49' - # ... else: - # ... assert model.__module__ == 'ToyNet2d_d573a3' - - Example: - >>> # Test the zip file as the model deployment - >>> zip_fpath = _demodata_zip_fpath() - >>> self = DeployedModel(zip_fpath) - >>> model_ = self.model_definition() - >>> initializer_ = self.initializer_definition() - >>> model = model_[0](**model_[1]) - >>> assert initializer_[1].get('fpath', None) is not None, 'initializer isnt setup correctly' - >>> initializer = initializer_[0](**initializer_[1]) - >>> initializer(model) - ... - >>> # NOTE: the module name should be consistent, but due to - >>> # small library changes it often changes, so we are permissive - >>> # with this got/want test - >>> print('model.__module__ = {!r}'.format(model.__module__)) - model.__module__ = 'deploy_ToyNet2d_..._.../ToyNet2d_...' - - model.__module__ = 'deploy_ToyNet2d_mhuhweia_000_.../ToyNet2d_...' 
- - model.__module__ = 'deploy_ToyNet2d_rljhgepw_000_.../ToyNet2d_2a3f49' - """ - def __init__(self, path): - self.path = path - self._model = None - self._info = None - - @classmethod - def custom(DeployedModel, snap_fpath, model, initkw=None, train_info_fpath=None): - """ - Create a deployed model even if the model wasnt trained with FitHarn - - This just requires specifying a bit more information, which FitHarn - would have tracked. - - Args: - snap_fpath (PathLike): - path to the exported (snapshot) weights file - - model (PathLike or nn.Module): can either be - (1) a path to model topology (created via `export_model_code`) - (2) the model class or an instance of the class - - initkw (Dict): if model is a class or instance, then - you must pass the keyword arguments used to construct it. - - train_info_fpath (PathLike, optional): - path to a json file containing additional training metadata - - Example: - >>> # Setup raw components - >>> train_dpath = _demodata_trained_dpath() - >>> deployed = DeployedModel(train_dpath) - >>> snap_fpath = deployed.info['snap_fpath'] - >>> model, initkw = deployed.model_definition() - >>> train_info_fpath = deployed.info['train_info_fpath'] - >>> # Past raw components to custom - >>> self = DeployedModel.custom(snap_fpath, model, initkw) - >>> dpath = ub.ensure_app_cache_dir('netharn', 'tests/_package_custom') - >>> self.package(dpath) - """ - if isinstance(model, six.string_types): - model_fpath = model - if initkw is not None: - raise ValueError('initkw not used when model is a path') - else: - import tempfile - from netharn.export import exporter - dpath = tempfile.mkdtemp() - model_fpath = exporter.export_model_code(dpath, model, initkw=initkw) - - _info = { - 'model_fpath': model_fpath, - 'snap_fpath': snap_fpath, - 'train_info_fpath': train_info_fpath, - } - self = DeployedModel(None) - self._info = _info - return self - - def __nice__(self): - return self.__json__() - - def __json__(self): - if self.path is None: - if self._info: - return ub.repr2(self._info, nl=0) - else: - return self.path - - def package(self, dpath=None, name=None): - """ - If self.path is a directory, packages important info into a deployable - zipfile. - - Args: - dpath (PathLike, optional): directory to dump your packaged model. - If not specified, it uses the netharn train_dpath if available. - name (str, default=None): the name of the zipfile to deploy to. - If not specified, one will be constructed. 
- - Returns: - PathLike: path to single-file deployment - """ - if dpath is None: - if self.path is None: - raise ValueError('Must specify dpath for custom deployments') - else: - if self.path.endswith('.zip'): - raise Exception('Deployed model is already a package') - dpath = self.path - - zip_fpath = _package_deploy2(dpath, self.info, name=name) - return zip_fpath - - @property - def info(self): - if self._info is None: - self._info = self.unpack_info() - return self._info - - def unpack_info(self): - return unpack_model_info(self.path) - - def model_definition(self): - model_fpath = self.info['model_fpath'] - module = ub.import_module_from_path(model_fpath) - - export_version = getattr(module, '__pt_export_version__', '0') - export_version = list(map(int, export_version.split('.'))) - if export_version >= [0, 2, 0]: - model_cls = module.get_model_cls() - initkw = module.get_initkw() - else: - # Hack to get information from older versions of pytorch_export - import inspect - from xdoctest import static_analysis - print('Hacking to grab model_cls and initkw') - model = module.make() - model_cls = model.__class__ - source = inspect.getsource(module.make) - print(source) - initkw = static_analysis.parse_static_value('initkw', source=source) - # Try to reconstruct initkw - model_ = (model_cls, initkw) - return model_ - - def initializer_definition(self): - import netharn as nh - initializer_ = (nh.initializers.Pretrained, - {'fpath': self.info['snap_fpath']}) - return initializer_ - - def train_info(self): - import netharn as nh - train_info_fpath = self.info.get('train_info_fpath', None) - if train_info_fpath is not None: - train_info = json.load(nh.util.zopen(train_info_fpath, 'r')) - else: - train_info = None - return train_info - - def load_model(self): - if self._model is not None: - return self._model - - model_cls, model_kw = self.model_definition() - model = model_cls(**model_kw) - - if True: - # Always load models onto the CPU first - # import netharn as nh - model = model.to('cpu') - # devices = {k: item.device for k, item in model.state_dict().items()} - # nh.XPU.from_data(model) - - # TODO: load directly from instead of using initializer self.info['snap_fpath']? - # Actually we can't because we lose the zopen stuff. Its probably ok - # To depend on netharn a little bit. - # import torch - # info = self.unpack_info() - # state_dict = torch.load(self.info['snap_fpath']) - # model.load_state_dict() - - initializer_ = self.initializer_definition() - initializer = initializer_[0](**initializer_[1]) - - assert model is not None - - initializer(model) - return model - - @classmethod - def ensure_mounted_model(cls, deployed, xpu=None, log=print): - """ - Ensure that a deployed model is loaded and mounted. - - Helper method that can accept either a raw model or packaged deployed - model is loaded and mounted on a specific XPU. This provides a one line - solution for applications that may want to ensure that a model is - mounted and ready for predict. When the model is already mounted this - is very fast and just passes the data through. If the input is a - packaged deployed file, then it does the required work to prep the - model. - - Args: - deployed (DeployedModel | PathLike | torch.nn.Module): - either a packed deployed model, a path to a deployed model, or - an already mounted torch Module. - xpu (str | XPU): which device to mount on - log (callable, optional): logging or print function - - Returns: - Tuple[torch.nn.Module, XPU]: - the mounted model, and the device it is mounted on. 
- """ - import netharn as nh - import torch - if not isinstance(xpu, nh.XPU): - xpu = nh.XPU.coerce(xpu) - - if isinstance(deployed, six.string_types): - deployed = nh.export.DeployedModel(deployed) - - if isinstance(deployed, torch.nn.Module): - # User passed in the model directly - model = deployed - try: - if xpu != nh.XPU.from_data(model): - log('Re-Mount model on {}'.format(xpu)) - model = xpu.mount(model) - except Exception: - log('Re-Mount model on {}'.format(xpu)) - model = xpu.mount(model) - elif isinstance(deployed, nh.export.DeployedModel): - model = deployed.load_model() - log('Mount {} on {}'.format(deployed, xpu)) - model = xpu.mount(model) - else: - raise TypeError('Unable to ensure {!r} as a mounted model'.format( - deployed)) - - return model, xpu - - @classmethod - def coerce(DeployedModel, arg): - """ - Attempt to coerce the argument into a deployed model. - - Args: - arg (DeployedModel | PathLike | torch.nn.Module) : can be: - (1) a DeployedModel object - (2) a path to a deploy file - (3) a live pytorch module - (4) a path to a .pt file in a netharn train snapshot directory. - - Returns: - DeployedModel - """ - from os.path import dirname - import torch - if isinstance(arg, DeployedModel): - # The input is already a DeployedModel - deployed = arg - elif isinstance(arg, torch.nn.Module): - # The argument is a live pytorch model - deployed = DeployedModel(None) - deployed._model = arg - elif isinstance(arg, six.string_types): - # handle the case where we are given a weights file - # use heuristics try and determine model topology - if arg.endswith('.pt'): - snap_fpath = arg - dpath_cands = [] - # Look the pt file's directory for topology and train info - dpath1 = dirname(snap_fpath) - dpath_cands = [dpath1] - # The files might also be in the parent directory - if not exists(join(dpath1, 'train_info.json')): - dpath_cands.append(dirname(dpath1)) - # Search for the files in the candidate directories - train_info_cands = list(ub.find_path( - 'train_info.json', path=dpath_cands, exact=True)) - model_cands = list(ub.find_path( - '*.py', path=dpath_cands, exact=False)) - if len(model_cands) == 0: - raise Exception('Model topology does not exist for {!r}.'.format(arg)) - elif len(model_cands) > 1: - raise Exception('Conflicting model topologies for {!r}.'.format(arg)) - else: - model_fpath = model_cands[0] - if len(train_info_cands) == 0: - train_info_fpath = None - elif len(train_info_cands) > 1: - raise AssertionError('Conflicting train_info.json files') - else: - train_info_fpath = train_info_cands[0] - deployed = DeployedModel.custom(snap_fpath, model_fpath, - train_info_fpath=train_info_fpath) - else: - # Assume we have a netharn deploy path - deployed = DeployedModel(arg) - else: - # Unhandled case - raise TypeError(type(arg)) - return deployed - - -def _demodata_zip_fpath(): - zip_path = DeployedModel(_demodata_trained_dpath()).package() - return zip_path - - -def _demodata_toy_harn(): - # This will train a toy model with toy data using netharn - import netharn as nh - hyper = nh.HyperParams(**{ - 'workdir' : ub.ensure_app_cache_dir('netharn/tests/deploy'), - 'name' : 'deploy_demo_static', - 'xpu' : nh.XPU.coerce('cpu'), - 'datasets' : {'train': nh.data.ToyData2d(size=3, rng=0)}, - 'loaders' : {'batch_size': 64}, - 'model' : (nh.models.ToyNet2d, {}), - 'optimizer' : (nh.optimizers.SGD, {'lr': 0.0001}), - 'criterion' : (nh.criterions.FocalLoss, {}), - 'initializer' : (nh.initializers.KaimingNormal, {}), - 'monitor' : (nh.Monitor, {'max_epoch': 1}), - }) - harn = 
nh.FitHarn(hyper) - harn.preferences['use_tensorboard'] = False - return harn - - -def _demodata_trained_dpath(): - harn = _demodata_toy_harn() - harn.run() # TODO: make this run faster if we don't need to rerun - if len(list(glob.glob(join(harn.train_dpath, '*.py')))) > 1: - # If multiple models are deployed some hash changed. Need to reset - harn.initialize(reset='delete') - harn.run() # don't relearn if we already finished this one - return harn.train_dpath - - -if __name__ == '__main__': - """ - CommandLine: - xdoctest -m netharn.export.deployer all - """ - import xdoctest - xdoctest.doctest_module(__file__) +from torch_liberator.deployer import * # NOQA +import warnings +warnings.warn('netharn.export.deployer is deprecated, use torch_liberator.deployer intead', DeprecationWarning) diff --git a/netharn/export/exporter.py b/netharn/export/exporter.py index b6718975906a3ab2e05156b07cbfc896d7f37726..0736c33762555f6a842018714500234c386b0dec 100644 --- a/netharn/export/exporter.py +++ b/netharn/export/exporter.py @@ -1,346 +1,4 @@ # -*- coding: utf-8 -*- -""" -Export component of the Pytorch exporter. - -This is the code that simply exports the model toplogy via code - -Uses static analysis to export relevant code that defines the model topology -into a stanadlone file. As long as your model definition is indepenent of your -training code, then the exported file can be passed around in a similar way to -a caffe prototext file. - -TODO: - - [ ]: Look into: https://www.reddit.com/r/MachineLearning/comments/a856oe/d_pytorch_10_deployment_pipeline/ec9w94c/ - - >>> from torchvision.models import densenet - >>> import torch - >>> model = densenet.DenseNet(growth_rate=16).eval() - >>> traced = torch.jit.trace(model, example_inputs=(torch.randn(2, 3, 224, 224), )) - >>> traced.save("densenet.pt") - >>> model_ = torch.jit.load("densenet.pt") - - -CommandLine: - xdoctest -m netharn.export.exporter export_model_code - xdoctest -m netharn.export.exporter source_closure:1 - - xdoctest -m netharn.export.exporter all -""" -from __future__ import absolute_import, division, print_function, unicode_literals -import ast -import six # NOQA -import re -import hashlib -import io -import pickle -import tokenize -import ubelt as ub +from torch_liberator.exporter import * # NOQA import warnings -from os.path import join -from . import closer - -__all__ = ['export_model_code'] - - -__pt_export_version__ = '0.5.1' - - -def export_model_code(dpath, model, initkw=None, export_modules=[]): - """ - Exports the class used to define a pytorch model as a new python module. - - Exports the minimum amount of code needed to make a self-contained Python - module defining the pytorch model class. This exports the actual source - code. The advantage of using this over pickle is that the original code can - change arbitrarilly because all dependencies on the original code are - removed in the exported code. - - Args: - dpath (str): directory to dump the model - model (tuple or type or object): class or class instance (e.g. torch.nn.Module) - name (str): name to use for the file (defaults to the classname) - initkw (dict): if specified, creates the function `make`, which - initializes the network with the specific arguments. - export_modules (List[str]): A list of modules that the exported code - should not depend on. Any code referenced from these modules will - be statically extracted and copied into the model definition. - Note that this feature is experimental. - - Returns: - str: static_modpath: path to the saved model file. 
- While you could put the output path in your PYTHONPATH, it is best - to use `ub.import_module_from_path` to "load" the model instead. - - Example: - >>> from netharn.export.exporter import export_model_code - >>> from torchvision.models import densenet - >>> import torchvision - >>> from os.path import basename - >>> initkw = {'growth_rate': 16} - >>> model = densenet.DenseNet(**initkw) - >>> dpath = ub.ensure_app_cache_dir('netharn/tests') - >>> static_modpath = export_model_code(dpath, model, initkw) - >>> print('static_modpath = {!r}'.format(static_modpath)) - ... - >>> mod_fname = (basename(static_modpath)) - >>> print('mod_fname = {!r}'.format(mod_fname)) - >>> if torchvision.__version__ == '0.2.2': - >>> if six.PY2: - >>> assert mod_fname == 'DenseNet_b7ec43.py', 'got={}'.format(mod_fname) - >>> else: - >>> assert mod_fname == 'DenseNet_256629.py', 'got={}'.format(mod_fname) - >>> # now the module can be loaded - >>> module = ub.import_module_from_path(static_modpath) - >>> loaded = module.make() - >>> assert model.features.denseblock1.denselayer1.conv2.out_channels == 16 - >>> assert loaded.features.denseblock1.denselayer1.conv2.out_channels == 16 - >>> assert model is not loaded - """ - if isinstance(model, type): - model_class = model - else: - model_class = model.__class__ - classname = model_class.__name__ - - if initkw is None: - raise NotImplementedError( - 'ERROR: The params passed to the model __init__ must be available') - footer = '' - else: - # First see if we can get away with a simple encoding of initkw - try: - # Do not use repr. The text produced is non-deterministic for - # dictionaries. Instead, use ub.repr2, which is deterministic. - init_text = ub.repr2(initkw, nl=1) - eval(init_text, {}) - init_code = ub.codeblock( - 'initkw = {}' - ).format(init_text) - except Exception: - # fallback to pickle - warnings.warn('Initialization params might not be serialized ' - 'deterministically') - init_bytes = repr(pickle.dumps(initkw, protocol=0)) - init_code = ub.codeblock( - ''' - import pickle - initkw = pickle.loads({}) - ''' - ).format(init_bytes) - init_code = ub.indent(init_code).lstrip() - # create a function to instanciate the class - footer = '\n\n' + ub.codeblock( - ''' - __pt_export_version__ = '{__pt_export_version__}' - - - def get_initkw(): - """ creates an instance of the model """ - {init_code} - return initkw - - - def get_model_cls(): - model_cls = {classname} - return model_cls - - - def make(): - """ creates an instance of the model """ - initkw = get_initkw() - model_cls = get_model_cls() - model = model_cls(**initkw) - return model - ''' - ).format(classname=classname, init_code=init_code, - __pt_export_version__=__pt_export_version__) - - # TODO: assert that the name "make" is not used in the model body - - body = closer.source_closure(model_class, expand_names=export_modules) - - body_footer = body + footer + '\n' - # dont need to hash the header, because comments are removed anyway - - # with open('debug-closer.py', 'w') as file: - # file.write(body_footer) - hashid = hash_code(body_footer) - - header = ub.codeblock( - ''' - """ - This module was autogenerated by netharn/export/exporter.py - original_module={} - classname={} - timestamp={} - hashid={} - """ - ''').format(model_class.__module__, classname, ub.timestamp(), hashid) - - sourcecode = header + '\n' + body_footer - - static_modname = classname + '_' + hashid[0:6] - static_modpath = join(dpath, static_modname + '.py') - with open(static_modpath, 'w') as file: - file.write(sourcecode) - 
return static_modpath - - -def remove_comments_and_docstrings(source): - r""" - Args: - source (str): uft8 text of source code - - Returns: - str: out: the source with comments and docstrings removed. - - References: - https://stackoverflow.com/questions/1769332/remove-comments-docstrings - - Example: - >>> source = ub.codeblock( - ''' - def foo(): - 'The spaces before this docstring are tokenize.INDENT' - test = [ - 'The spaces before this string do not get a token' - ] - ''') - >>> out = remove_comments_and_docstrings(source) - >>> want = ub.codeblock( - ''' - def foo(): - pass - test = [ - 'The spaces before this string do not get a token' - ]''').splitlines() - >>> got = [o.rstrip() for o in out.splitlines()] - >>> assert got == want - - - >>> source = ub.codeblock( - ''' - def foo(): - " docstring " - ''') - >>> out = remove_comments_and_docstrings(source) - >>> print(out) - >>> source = ub.codeblock( - ''' - class foo(): - r{qqq} - docstring - {qqq} - ''').format(qqq='"' * 3) - >>> out = remove_comments_and_docstrings(source) - >>> print(out) - - """ - source = ub.ensure_unicode(source) - io_obj = io.StringIO(source) - output_parts = [] - prev_toktype = tokenize.INDENT - last_lineno = -1 - last_col = 0 - for tok in tokenize.generate_tokens(io_obj.readline): - token_type = tok[0] - token_string = tok[1] - start_line, start_col = tok[2] - end_line, end_col = tok[3] - # ltext = tok[4] - # The following two conditionals preserve indentation. - # This is necessary because we're not using tokenize.untokenize() - # (because it spits out code with copious amounts of oddly-placed - # whitespace). - if start_line > last_lineno: - last_col = 0 - if start_col > last_col: - output_parts.append((' ' * (start_col - last_col))) - # Remove comments: - if token_type == tokenize.COMMENT: - pass - # This series of conditionals removes docstrings: - elif token_type == tokenize.STRING: - if prev_toktype != tokenize.INDENT: - # This is likely a docstring; double-check we're not inside an - # operator: - if prev_toktype != tokenize.NEWLINE: - # Note regarding NEWLINE vs NL: The tokenize module - # differentiates between newlines that start a new statement - # and newlines inside of operators such as parens, brackes, - # and curly braces. Newlines inside of operators are - # NEWLINE and newlines that start new code are NL. - # Catch whole-module docstrings: - if start_col > 0: - # Unlabelled indentation means we're inside an operator - output_parts.append(token_string) - # Note regarding the INDENT token: The tokenize module does - # not label indentation inside of an operator (parens, - # brackets, and curly braces) as actual indentation. - else: - # NOTE: simply removing docstrings may create invalid code - # in cases where the only body is a docstring (e.g. a - # custom exception). Insert a pass to prevent this. It - # would be nice to detect when this is necessary. - output_parts.append('pass') - else: - output_parts.append(token_string) - prev_toktype = token_type - last_col = end_col - last_lineno = end_line - out = ''.join(output_parts) - return out - - -def hash_code(sourcecode): - r""" - Hashes source code text, but tries to normalize things like whitespace and - comments, so very minor changes wont change the hash. - - Args: - source (str): uft8 text of source code - - Returns: - str: hashid: 128 character (512 byte) hash of the normalized input - - Notes: - The return value of this function is based on the AST parse tree, which - might change between different version of Python. 
However, within the - same version of Python, the results should be consistent. - - CommandLine: - xdoctest -m /home/joncrall/code/netharn/netharn/export/exporter.py hash_code - - Example: - >>> hashid1 = (hash_code('x = 1')[0:8]) - >>> hashid2 = (hash_code('x=1 # comments and spaces dont matter')[0:8]) - >>> hashid3 = (hash_code('\nx=1')[0:8]) - >>> assert ub.allsame([hashid1, hashid2, hashid3]) - >>> hashid4 = hash_code('x=2')[0:8] - >>> assert hashid1 != hashid4 - """ - # Strip docstrings before making a parse tree - sourcecode = ub.ensure_unicode(sourcecode) - - stripped = remove_comments_and_docstrings(sourcecode) - - # Also remove pytorch_export version info (not sure if correct?) - stripped = re.sub('__pt_export_version__ = .*', '', stripped) - - parse_tree = ast.parse(stripped) - # hashing the parse tree will normalize for a lot possible small changes - ast_dump = ast.dump(parse_tree) - - hasher = hashlib.sha512() - hasher.update(ast_dump.encode('utf8')) - hashid = hasher.hexdigest() - return hashid - - -if __name__ == '__main__': - """ - CommandLine: - xdoctest -m netharn.export.exporter - """ - import xdoctest - xdoctest.doctest_module(__file__) +warnings.warn('netharn.export.exporter is deprecated, use torch_liberator.exporter intead', DeprecationWarning) diff --git a/netharn/fit_harn.py b/netharn/fit_harn.py index e2937e4978cdf9b948a34e251c1b4dd0a8b7630e..bb831f0b8760845200a82140701015e257b3d284 100644 --- a/netharn/fit_harn.py +++ b/netharn/fit_harn.py @@ -38,8 +38,7 @@ Note: CommandLine: xdoctest netharn.fit_harn __doc__:0 - xdoctest netharn.fit_harn __doc__:0 --progiter - xdoctest netharn.fit_harn __doc__:0 --progiter --profile --xpu=cpu + xdoctest netharn.fit_harn __doc__:0 --profile --xpu=cpu Example: >>> import netharn as nh @@ -88,7 +87,7 @@ Example: >>> # non-algorithmic behavior configs (do not change learned models) >>> harn.preferences['use_tensorboard'] = False >>> harn.preferences['timeout'] = 0.5 - >>> # harn.preferences['colored'] = False + >>> harn.preferences['auto_prepare_batch'] = True >>> # start training. >>> harn.initialize(reset='delete') >>> harn.run() # note: run calls initialize it hasn't already been called. @@ -129,6 +128,14 @@ TODO: [ ] - ability to run an iteration of the validation data within an epoch, perhaps we could allow the user to redefine how long an epoch is. + [ ] - Update for torch 1.1 lr scheduler behavior. Allow schedulers to be + called either after each epoch or after each batch iteration (for + schedulers like CyclicLR, OneCycleLR). + + [X] - Show LR in the batch progress bar (if updated on an iteration basis) + [ ] - How does the netharn scheduler redesign interact with torch 1.1? 
+ [ ] - Stochastic Weight Averaging - https://pytorch.org/docs/stable/optim.html#putting-it-all-together + """ from __future__ import absolute_import, division, print_function, unicode_literals import glob @@ -145,16 +152,17 @@ import traceback from os.path import join from os.path import exists from os.path import dirname +from distutils.version import LooseVersion import torch import numpy as np import ubelt as ub import scriptconfig as scfg +import torch_liberator from netharn import hyperparams from netharn import util -from netharn import export from netharn.util import profiler from netharn.util import strip_ansi from netharn.exceptions import (CannotResume, SkipBatch, StopTraining, @@ -524,7 +532,8 @@ class InitializeMixin(object): # this allows us to print logging calls to the terminal stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setFormatter(s_formatter) - if ub.argflag('--verbose'): + + if harn.preferences['verbose'] > 1 or ub.argflag('--verbose'): stdout_handler.setLevel(logging.DEBUG) else: stdout_handler.setLevel(logging.INFO) @@ -534,6 +543,9 @@ class InitializeMixin(object): _log.addHandler(a_handler) _log.addHandler(stdout_handler) + # hack in attribute for internal use + _log._stdout_handler = stdout_handler + harn._log = _log harn.debug('Initialized logging') @@ -715,6 +727,7 @@ class ProgMixin(object): def _make_prog(harn, *args, **kw): chunksize = kw.pop('chunksize', None) + show_wall = kw.pop('show_wall', False) if harn.preferences['use_tqdm'] is not None: import warnings @@ -734,8 +747,14 @@ class ProgMixin(object): import tqdm # NOQA Prog = tqdm.tqdm elif harn.preferences['prog_backend'] == 'progiter': - Prog = functools.partial( - ub.ProgIter, chunksize=chunksize, verbose=1, time_thresh=2.0) + if LooseVersion(ub.__version__) >= LooseVersion('0.9.3'): + Prog = functools.partial( + ub.ProgIter, chunksize=chunksize, verbose=1, + time_thresh=2.0, show_wall=show_wall) + else: + Prog = functools.partial( + ub.ProgIter, chunksize=chunksize, verbose=1, + time_thresh=2.0) else: raise KeyError(harn.preferences['prog_backend']) return Prog(*args, **kw) @@ -751,6 +770,14 @@ class ProgMixin(object): str : the message to be used in the progress bar """ parts = ['{}:{:.4g}'.format(k, v) for k, v in metric_dict.items()] + + if learn and harn.epoch == 0: + HACK_WARMUP = bool(harn.dynamics['warmup_iters']) + if HACK_WARMUP: + lrs = set(harn._current_lrs()) + lr_str = ','.join(['{:.4g}'.format(lr) for lr in lrs]) + parts.append('lr=' + lr_str) + if harn.preferences['prog_backend'] == 'progiter': if learn and harn.scheduler and getattr(harn.scheduler, '__batchaware__', False): lr = harn.scheduler.get_lr() @@ -763,7 +790,7 @@ class ProgMixin(object): else: bs = 'x{}'.format(batch_size) parts = [bs] + parts - if six.PY2: + if not harn.preferences['allow_unicode'] or six.PY2: # work around a unicode issue with tqdm in python2 msg = ' | ' .join(parts) + ' |' else: @@ -785,15 +812,16 @@ class ProgMixin(object): def _update_main_prog_desc(harn): lrs = set(harn._current_lrs()) lr_str = ','.join(['{:.4g}'.format(lr) for lr in lrs]) - if six.PY2: + if not harn.preferences['allow_unicode'] or six.PY2: desc = 'epoch lr:{} | {}'.format(lr_str, harn.monitor.message()) else: desc = 'epoch lr:{} │ {}'.format(lr_str, harn.monitor.message()) if not harn.preferences['colored']: desc = strip_ansi(desc) - harn.debug(desc) harn.main_prog.set_description(desc, refresh=False) if isinstance(harn.main_prog, ub.ProgIter): + # Write progress message to the log file + 
harn.debug(harn.main_prog.format_message().strip()) if not harn.main_prog.started: # harn.main_prog.ensure_newline() harn.main_prog.clearline = False @@ -801,6 +829,7 @@ class ProgMixin(object): harn.main_prog.adjust = False harn.main_prog.begin() else: + harn.debug(desc) harn._update_prog_postfix(harn.main_prog) @@ -898,9 +927,16 @@ class LogMixin(object): msg (str): a debug message to log """ if harn._log: + + if harn._log._stdout_handler.level <= logging.DEBUG: + # Use our hacked attribute to ensure newlines if we are + # writting debug info to stdout + harn._ensure_prog_newline() + msg = strip_ansi(six.text_type(msg)) # Encode to prevent errors on windows terminals - # On windows there is a sometimes a UnicodeEncodeError: For more details see: https://wiki.python.org/moin/PrintFails + # On windows there is a sometimes a UnicodeEncodeError: + # For more details see: https://wiki.python.org/moin/PrintFails if sys.platform.startswith('win32'): harn._log.debug(msg.encode('utf8')) else: @@ -1100,7 +1136,7 @@ class SnapshotMixin(object): are: checkpoint, explicit, and initial. explicit (bool, default=False): if True, the snapshot is also - tagged by a hash and saved to the explit_checkpoints directory. + tagged by a hash and saved to the explicit_checkpoints directory. DEPRECTATED, use mode. Returns: @@ -1116,7 +1152,7 @@ class SnapshotMixin(object): mode = 'explicit' if mode == 'explicit': - dpath = ub.ensuredir((harn.train_dpath, 'explit_checkpoints')) + dpath = ub.ensuredir((harn.train_dpath, 'explicit_checkpoints')) stamp = ub.timestamp() save_fname = '_epoch_{:08d}_{}.pt'.format(harn.epoch, stamp) elif mode == 'checkpoint': @@ -1132,7 +1168,7 @@ class SnapshotMixin(object): save_fname = '_epoch_{:08d}.pt'.format(harn.epoch) elif mode == 'initial': dpath = ub.ensuredir((harn.train_dpath, 'initial_state')) - save_fname = 'initial_state.pt'.format(harn.epoch) + save_fname = 'initial_state.pt'.format() else: raise KeyError(mode) @@ -1157,6 +1193,9 @@ class SnapshotMixin(object): def best_snapshot(harn): """ Return the path to the current "best" snapshot. + + Returns: + str - find the path to the best """ # Netharn should populate best_snapshot.pt if there is a validation set. # Other names are to support older codebases. 
@@ -1183,9 +1222,8 @@ class SnapshotMixin(object): if epoch_to_fpath: fpath = epoch_to_fpath[max(epoch_to_fpath)] - if fpath is None: - raise Exception('cannot find / determine the best snapshot') - + # if fpath is None: + # raise Exception('cannot find / determine the best snapshot') return fpath @@ -1323,15 +1361,15 @@ class ScheduleMixin(object): warmup_iters = harn.dynamics['warmup_iters'] warmup_ratio = harn.dynamics['warmup_ratio'] # 1.0 / 3.0 if cur_iters < warmup_iters: - for cur_iters in range(0, warmup_iters): - regular_lr = _get_optimizer_values(harn.optimizer, 'initial_lr') - if warmup == 'linear': - k = (1 - (cur_iters + 1) / warmup_iters) * (1 - warmup_ratio) - warmup_lr = [_lr * (1 - k) for _lr in regular_lr] - else: - raise KeyError(warmup) - # harn.debug('warmup_lr = {}'.format(warmup_lr)) - _set_optimizer_values(harn.optimizer, 'lr', warmup_lr) + # for cur_iters in range(0, warmup_iters): + regular_lr = _get_optimizer_values(harn.optimizer, 'initial_lr') + if warmup == 'linear': + k = (1 - (cur_iters + 1) / warmup_iters) * (1 - warmup_ratio) + warmup_lr = [_lr * (1 - k) for _lr in regular_lr] + else: + raise KeyError(warmup) + # harn.debug('warmup_lr = {}'.format(warmup_lr)) + _set_optimizer_values(harn.optimizer, 'lr', warmup_lr) # TODO: REFACTOR SO NETHARN HAS A PROPER ITERATION MODE if getattr(harn.scheduler, '__batchaware__', False): @@ -1467,7 +1505,8 @@ class CoreMixin(object): total=harn.monitor.max_epoch, disable=not harn.preferences['show_prog'], leave=True, dynamic_ncols=True, - position=0, initial=harn.epoch) + show_wall=True, position=0, + initial=harn.epoch) harn._update_main_prog_desc() # Loader dict should be ordered @@ -1630,7 +1669,7 @@ class CoreMixin(object): model_class = harn.hyper.model_cls model_params = harn.hyper.model_params export_modules = harn.preferences['export_modules'] - static_modpath = export.export_model_code( + static_modpath = torch_liberator.export_model_code( harn.train_dpath, model_class, initkw=model_params, export_modules=export_modules) harn.info('Exported model topology to {}'.format(static_modpath)) @@ -1647,16 +1686,24 @@ class CoreMixin(object): Returns: str: path to the deploy zipfile. 
""" - harn._export() + static_modpath = harn._export() harn.debug('packaging deploying model') if True: - # HOTFIX: if the best snapshot doesnt exist we need to make one - if export.deployer.find_best_snapshot(harn.train_dpath) is None: - harn.save_snapshot() + snap_fpath = harn.best_snapshot() + if snap_fpath is None: + # if the best snapshot doesnt exist we need to make one + harn.debug( + 'Cannot find "best" snapshot, write an explit one instead') + snap_fpath = harn.save_snapshot(explicit=True) try: - deploy_fpath = export.DeployedModel(harn.train_dpath).package() + train_info_fpath = join(harn.train_dpath, 'train_info.json') + deploy_fpath = torch_liberator.DeployedModel.custom( + snap_fpath=snap_fpath, + model=static_modpath, + train_info_fpath=train_info_fpath, + ).package(harn.train_dpath) harn.info('wrote single-file deployment to: {!r}'.format( deploy_fpath)) @@ -1893,6 +1940,8 @@ class CoreMixin(object): ### THIS IS THE CRITICAL LOOP ### ################################# + STEP_LR_BEFORE = True + for bx in range(n_batches): if DEMO and bx > DEMO_BX: break @@ -1906,6 +1955,12 @@ class CoreMixin(object): harn.bxs[tag] = bx # harn.debug('{} batch iteration {}'.format(tag, bx)) + if STEP_LR_BEFORE: + if learn: + # Some schedulers update every batch + # TODO: needs further rectification + harn._step_scheduler_batch() + batch = harn.prepare_batch(raw_batch) if is_profiling: @@ -1985,14 +2040,26 @@ class CoreMixin(object): # hack to force progiter to reach 100% at the end # This should be fixed in progiter. steps_taken = (bx - prog._iter_idx) + 1 - prog.update(steps_taken) + if bx == 0: + # HACK, after ubelt 0.9.3 we can use force=True + prog._iter_idx += steps_taken + prog._update_measurements() + prog._update_estimates() + prog.display_message() + harn.debug(prog.format_message().strip()) + else: + prog_updated = prog.update(steps_taken) + if prog_updated: + harn.debug(prog.format_message().strip()) if use_tqdm: harn._update_prog_postfix(prog) - # Some schedulers update every batch - if learn: - harn._step_scheduler_batch() + if not STEP_LR_BEFORE: + # old way that I think is buggy + if learn: + # Some schedulers update every batch + harn._step_scheduler_batch() except SkipBatch: harn.warn('skipping batch') if harn.check_interval('display_' + tag, bx): @@ -2012,6 +2079,8 @@ class CoreMixin(object): # harn.optimizer.zero_grad() prog.refresh() + if not use_tqdm: + harn.debug(prog.format_message().strip()) prog.close() harn.epoch_prog = None @@ -2253,38 +2322,37 @@ class CoreCallbacks(object): necessary to support distributed training. """ batch = raw_batch - import warnings - warnings.warn( - 'The behavior of prepare_batch will change in the future. ' - 'The new behavior will be a simple no-op ' - 'For maximum compatibility override prepare_batch.', - DeprecationWarning) - try: - if isinstance(raw_batch, (tuple, list)): - batch_inputs, batch_labels = raw_batch - raw_batch = { - 'input': batch_inputs, - 'label': batch_labels, - } - if isinstance(raw_batch, dict): - batch = raw_batch.copy() - batch = harn.xpu.move(batch) - else: - print('ERROR: raw_batch = {}'.format(type(raw_batch))) - raise TypeError( - 'could not prepare raw batch {}'.format(type(raw_batch))) - except Exception: - harn.warn('Error occurred in default prepare_batch. 
' - 'Perhaps you should overload it?') - raise - return batch + if harn.preferences['auto_prepare_batch']: + # Automatically move data + try: + if isinstance(raw_batch, (tuple, list)): + batch = harn.xpu.move(raw_batch) + elif isinstance(raw_batch, dict): + batch = raw_batch.copy() + batch = harn.xpu.move(batch) + else: + print('ERROR: raw_batch = {}'.format(type(raw_batch))) + raise TypeError( + 'could not prepare raw batch {}'.format(type(raw_batch))) + + except Exception: + harn.warn('Error occurred in default prepare_batch. ' + 'Perhaps you should overload it?') + raise + return batch + else: + return batch def run_batch(harn, batch): """ Basic connection inputs -> model -> outputs -> criterion -> loss - Overload Encouraged, but not always necessary + This is the meat and potatoes of your deep learning algorithm, + everything else is boilerplate. You define how to pass your inputs into + your model and then compute your loss here. We provide a default + implementation that will work for basic tasks as long as the model and + loss are well defined, but you will typically need to overload this. Note: You may return loss as a flat dictionary mapping string keys to @@ -2292,22 +2360,35 @@ class CoreCallbacks(object): and each loss component will be automatically logged. Args: - batch (object): the current batch + batch (object): the current batch as generated by the data loader. + Note: use :func:`ExtraMixins.._demo_batch` (i.e. + ``harn._demo_batch()``) to generate an example batch for + interactive / testing / other usage. Returns: - Tuple[object, Tensor|Dict]: (outputs, loss) + Tuple[object, Tensor|Dict]: + tuple containing: + outputs - whatever the output of the model was + loss - either a single scalar loss or a dictionary of + scalar losses (the harness use the keys as labels to + track different losses). """ # Simple forward prop and loss computation try: if isinstance(batch, dict): + # The extensible case where your batch is a dictionary with + # keys "input" and "label", which themselves are usually + # dictionaries. outputs = harn.model(batch['input']) loss = harn.criterion(outputs, batch['label']) - elif isinstance(batch, tuple): + elif isinstance(batch, (tuple, list)) and len(batch) == 2: + # The "standard" non-extensible case you see in tutorials where + # items from the dataset are returned as a input / label tuple inputs, labels = batch - outputs = harn.model(*inputs) - loss = harn.criterion(outputs, *labels) + outputs = harn.model(inputs) + loss = harn.criterion(outputs, labels) else: - raise TypeError('Could not run batch') + raise TypeError('Could not run batch: {}'.format(type(batch))) except Exception: if harn.criterion: harn.error('You must overwrite run_batch if ' @@ -2555,7 +2636,6 @@ class FitHarn(ExtraMixins, InitializeMixin, ProgMixin, LogMixin, SnapshotMixin, monitors performance of the validation set. SeeAlso `netharn.monitor`. - Note: hyper is optional. 
If you choose not to specify it then you must overwrite harn._setup_modules and create the requires class instances @@ -2799,12 +2879,27 @@ class FitHarnPreferences(scfg.Config): 'limits the amount of time training can take') ), + 'auto_prepare_batch': scfg.Value(False, help=( + 'In the case where prepare_batch is not overwritten, ' + 'changes the behavior of the default prepare_batch ' + 'to automatically move tensors onto the model XPU' + )), + + 'verbose': scfg.Value(1, help=( + 'verbosity level, ' + 'if >1 shows debug info in stdout')), + # Deprecated 'use_tqdm': scfg.Value(None, help='deprecated'), 'colored': scfg.Value(True, help=( 'allow for ANSI colored text in stdout logs, ' - 'otherwise it is stripped')), + 'otherwise it is stripped. ' + 'DEPRECATED use NO_COLOR environ instead')), + + 'allow_unicode': scfg.Value(True, help=( + 'allow for unicode characters in messages, otherwise ' + ' we approximate them with ascii')), } diff --git a/netharn/hyperparams.py b/netharn/hyperparams.py index 740eaffbc1726b8c6aaed0672f87d5c88ba23c7a..71ae2941837744ad0453bc5530f1df34934415c3 100644 --- a/netharn/hyperparams.py +++ b/netharn/hyperparams.py @@ -906,6 +906,9 @@ class HyperParams(object): temp_initializer = hyper.make_initializer() init_history = temp_initializer.history() + # TODO: software versions + + train_info = ub.odict([ ('train_hashid', train_hashid), @@ -1004,6 +1007,52 @@ class HyperParams(object): }) return hyper + +def module_version_infos(): + """ + + References: + https://packaging.python.org/guides/single-sourcing-package-version/ + """ + try: + from importlib import metadata + except ImportError: + # Running on pre-3.8 Python; use importlib-metadata package + import importlib_metadata as metadata + import sys + modnames = ['torch', 'cv2', 'netharn', 'PIL', 'numpy'] + infos = [] + for modname in modnames: + info = {'name': modname} + + try: + module = sys.modules[modname] + version_0 = getattr(module, '__version__', None) + except Exception: + version_0 = None + + try: + version_1 = metadata.version(modname) + except Exception: + version_1 = None + + possible_versions = {version_1, version_0} - {None} + if len(possible_versions) == 1: + info['version'] = ub.peek(possible_versions) + else: + info['possible_versions'] = possible_versions + + if modname == 'torch': + info['torch.version.cuda'] = torch.version.cuda + info['torch.cuda.is_available()'] = torch.cuda.is_available() + + infos.append(info) + + # The conda info step is too slow (3 seconds) + from netharn.util.collect_env import get_env_info + env_info = get_env_info()._asdict() + info['__env__'] = env_info + if __name__ == '__main__': r""" CommandLine: diff --git a/netharn/initializers/_nx_ext/__init__.py b/netharn/initializers/_nx_ext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e9e1bf22e4d816d96b0361ab2c62f29eb4f89f --- /dev/null +++ b/netharn/initializers/_nx_ext/__init__.py @@ -0,0 +1,76 @@ +""" +TEMPORARY FORK +-------------- + +CommandLine: + sedr networkx.algorithms.isomorphism._embedding netharn.initializers._nx_ext + sedr netharn.initializers._nx_ext netharn.initializers._nx_ext * True + + +Subpackages for helpers and such related to the ordered subtree embedding / +isomorphism problems. + +Contains routines for solving balanced sequence and path subproblems. Only the +final graph-based API is exposed, but modification to the internals (is / will +be) available via keyword arguments. 
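+
+A minimal sequence-level usage sketch, adapted from the doctest of
+``balanced_sequence.longest_common_balanced_sequence`` (the tree and path
+helpers reduce their problems to this form):
+
+>>> from netharn.initializers._nx_ext import balanced_sequence
+>>> seq1, seq2 = '[][[]][]', '[[]][[]]'
+>>> open_to_close = {'[': ']'}
+>>> best, value = balanced_sequence.longest_common_balanced_sequence(
+...     seq1, seq2, open_to_close)
+>>> best[0]
+'[][[]]'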
+ +balanced_sequence.py - core python implementations for the longest common +balanced sequence subproblem. + +balanced_sequence_cython.pyx - +faster alternative implementsions for balanced_sequence.py + +tree_embedding.py - defines reduction from tree problem to balanced sequence +problems. + +path_embedding.py - defines reduction from path problem to tree problem (not +core, this is just useful for testing among other things). + +demodata.py - Contains data for docstrings, benchmarks, and synthetic problems + + +Outstanding Issues +------------------ +- [ ] Multiple implementations of the algorithm backend / data structure + reduction, need to reduce the impelmentation and / or determine a better + mechansim for allowing the user to switch between them. + +- [ ] strhack is not a good API in `tree_to_seq` + +- [ ] Should we return which edges were contracted in each tree to create the + embeddings? That seems useful (but maybe not equivalent to the embeddings + themselves?) + +- [ ] How to deal with cython + networkx? Do we need to fix that skbuild with + pypy? + +- [ ] The open_to_node problem: + Note, we may be able to simply use the position of each opening token + as a proxy for unique tokens. Pass in an ordered list of nodes, then + just use their indexes. + + +CommandLine +----------- +xdoctest -m netharn.initializers._nx_ext list +xdoctest -m netharn.initializers._nx_ext all + +# Run all tests in this module +DPATH=$(python -c " +import os; import netharn.initializers._nx_ext as m; +print(os.path.dirname(m.__file__))") +pytest --xdoctest $DPATH --xdoc-analysis=dynamic + +# The mkinit tool helps autogenerate explicit `__init__.py` files +mkinit ~/code/networkx/netharn.initializers._nx_ext/__init__.py -w +""" + +__submodules__ = [ + 'tree_embedding', +] + +# from netharn.initializers._nx_ext import tree_embedding +from netharn.initializers._nx_ext.tree_embedding import ( + maximum_common_ordered_tree_embedding) + +__all__ = ['maximum_common_ordered_tree_embedding'] diff --git a/netharn/initializers/_nx_ext/_bseq_expt.py b/netharn/initializers/_nx_ext/_bseq_expt.py new file mode 100644 index 0000000000000000000000000000000000000000..5fad7de9cd800f7b8cf3baa9eb6731f0f334e2e5 --- /dev/null +++ b/netharn/initializers/_nx_ext/_bseq_expt.py @@ -0,0 +1,392 @@ +from netharn.initializers._nx_ext.balanced_sequence import UnbalancedException, IdentityDict # NOQA +from netharn.initializers._nx_ext.balanced_sequence import generate_all_decomp, _cython_lcs_backend, _lcs_iter_simple_alt2, _lcs_iter_prehash2, _lcs_recurse, _lcs_iter_simple, _lcs_iter_simple_alt1, _lcs_iter_prehash # NOQA + + +def _lcs_iter_simple_alt3(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node): + """ + Depth first stack trajectory and replace try except statements with ifs + + This is the current best pure-python algorithm candidate + + >>> full_seq1 = '{({})([[]([]){(()(({()[]({}{})}))){}}])}' + >>> full_seq2 = '{[({{}}{{[][{}]}(()[(({()})){[]()}])})]}' + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> full_seq1 = '[][[]][]' + >>> full_seq2 = '[[]][[]]' + >>> open_to_close = {'[': ']'} + >>> import operator as op + >>> node_affinity = op.eq + >>> open_to_node = IdentityDict() + >>> res = _lcs_iter_simple_alt3(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + >>> embeddings, val, delseq = res + >>> print('embeddings = {!r}'.format(embeddings[0])) + >>> print('delseq = {!r}'.format(delseq[0])) + """ + all_decomp1 = generate_all_decomp(full_seq1, open_to_close, open_to_node) + 
all_decomp2 = generate_all_decomp(full_seq2, open_to_close, open_to_node) + + key0 = (full_seq1, full_seq2) + frame0 = key0 + stack = [frame0] + + # Memoize mapping (seq1, seq2) -> best size, embeddings, deleted edges + _results = {} + + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + best = (empty1, empty2) + base_result = (0, best, ([], [])) + for seq1 in all_decomp1.keys(): + key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result + + del frame0 + del empty1 + del empty2 + del best + del base_result + + while stack: + key = stack[-1] + if key not in _results: + seq1, seq2 = key + + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + + # Case 2: The current edge in sequence1 is deleted + try_key = (head_tail1, seq2) + if try_key in _results: + cand1 = _results[try_key] + x, y, z = cand1 + z1, z2 = z + z1 = z1 + [a1] + z2 = z2 + [a2] + z3 = (z1, z2) + cand1 = (x, y, z3) + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 3: The current edge in sequence2 is deleted + try_key = (seq1, head_tail2) + if try_key in _results: + cand2 = _results[try_key] + x, y, z = cand2 + z1, z2 = z + z1 = z1 + [a1] + z2 = z2 + [a2] + z3 = (z1, z2) + cand2 = (x, y, z3) + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try_key = (head1, head2) + if try_key in _results: + pval_h, new_heads, delseq_h = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + try_key = (tail1, tail2) + if try_key in _results: + pval_t, new_tails, delseq_t = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + + h1, h2 = delseq_h + t1, t2 = delseq_t + + delseq3 = (h1 + t1, h2 + t2) + cand3 = (val3, res3, delseq3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + stack.pop() + + val, best, delseq = _results[key0] + found = (best, val, delseq) + return found + + +def balanced_decomp2(sequence, open_to_close, start=0): + gen = generate_balance2(sequence, open_to_close) + for tup in gen: + (bal_curr, tok_curr, idx1, idx2) = tup + if idx2 == start: + stop = idx1 + assert bal_curr + break + + return start, stop + # pop_open = sequence[0:1] + # pop_close = sequence[head_stop:head_stop + 1] + # head = sequence[1:head_stop] + # tail = sequence[head_stop + 1:] + # head_tail = head + tail + # return pop_open, pop_close, head, tail, head_tail + + +def generate_balance2(sequence, open_to_close, start=0): + """ + Alternate version that also returns index information + + Yields + ------ + bool, T, int, int + is balanced + opening token + opening token index + current token index + + + 
Example + ------- + >>> open_to_close = {0: 1} + >>> seq = sequence = [0, 0, 0, 1, 1, 1, 0, 1] + >>> gen = list(generate_balance2(sequence, open_to_close)) + >>> for flag, token, idx1, idx2 in gen: + >>> print('flag={:d}, token={}, {}, {}'.format(flag, token, idx1, idx2)) + + balanced_decomp2(sequence, open_to_close) + """ + stack = [] + # Traversing the Expression + for curr_idx, token in enumerate(sequence, start=start): + + if token in open_to_close: + # Push opening elements onto the stack + stack.append((token, curr_idx)) + open_idx = -1 + else: + # Check that closing elements + if not stack: + raise UnbalancedException + prev_open, open_idx = stack.pop() + want_close = open_to_close[prev_open] + + if token != want_close: + raise UnbalancedException + + # If the stack is empty the sequence is currently balanced + currently_balanced = not bool(stack) + yield currently_balanced, token, curr_idx, open_idx + + if stack: + raise UnbalancedException + + +def generate_all_decomp2(full_seq, open_to_close, open_to_node=None): + """ + Alternate version where we keep track of indices instead + + Example + ------- + >>> full_seq = '0010010010111101' + >>> open_to_close = {'0': '1'} + >>> full_seq = '{[{}]}[()]' + >>> open_to_close = {'[': ']', '{': '}', '(': ')'} + >>> list(generate_balance2(full_seq, open_to_close)) + >>> all_decomp = generate_all_decomp2(full_seq, open_to_close) + + >>> from netharn.initializers._nx_ext import demodata + >>> full_seq, open_to_close = demodata.random_balanced_sequence(5, mode='number') + >>> all_decomp = generate_all_decomp2(full_seq, open_to_close) + """ + if open_to_node is None: + open_to_node = IdentityDict() + all_decomp = {} + + start = 0 + stop = len(full_seq) + deleted = [] + stack = [ + ('f', full_seq, start, stop, deleted) + ] + + DEBUG = 1 + + while stack: + t, seq, seq_start, seq_stop, seq_del = stack.pop() + if DEBUG: + import ubelt as ub + print('-----') + print(list(full_seq)) + + isdel = ['X' if b else ' ' for b in ub.boolmask(seq_del, len(full_seq))] + sep = ' : ' + pos = list(' ' * len(full_seq)) + pos[seq_start] = 'S' + pos[seq_stop - 1] = 'T' + prefix = ': ' + def padjoin(s): + return sep.join(['{:>2}'.format(c) for c in s]) + print(prefix + padjoin(range(len(full_seq)))) + print(prefix + padjoin(full_seq) + ' <- full_seq') + print(prefix + padjoin(isdel) + ' <- seq_del') + print(prefix + padjoin(pos) + ' <- seq_start, seq_stop') + + val = seq_start, seq_stop, seq_del + print('seq = {}, {!r}, {}'.format(t, seq, val)) + base = full_seq[seq_start:seq_stop] + print('base = {!r}'.format(base)) + rel_pad_del = [idx - seq_start for idx in seq_del if idx >= seq_start] + keep_idxs = sorted(set(range(len(base))) - set(rel_pad_del)) + newlist = [base[idx] for idx in keep_idxs] + try: + recon = ''.join(newlist) + except TypeError: + recon = tuple(newlist) + print('recon = {!r}'.format(recon)) + if seq: + rel_start, rel_stop = balanced_decomp2(seq, open_to_close) + + rel_head_start = rel_start + 1 + rel_head_stop = rel_stop + rel_tail_start = rel_stop + 1 + rel_tail_stop = len(seq) + if DEBUG > 1: + print('rel_start = {!r}'.format(rel_start)) + print('rel_stop = {!r}'.format(rel_stop)) + print('rel_head_start = {!r}'.format(rel_head_start)) + print('rel_head_stop = {!r}'.format(rel_head_stop)) + print('rel_tail_start = {!r}'.format(rel_tail_start)) + print('rel_tail_stop = {!r}'.format(rel_tail_stop)) + + rel_pad_del = [idx - seq_start for idx in seq_del if seq_start <= idx <= seq_stop] + if DEBUG: + print('rel_pad_del = {!r}'.format(rel_pad_del)) + + # 
I think there is a cumsum way of doing this, I'm being dense atm + # seq = '3' * 10 + # rel_pad_del = [4, 5, 9, 11] + hack_map = list(range(1 + len(seq) + len(rel_pad_del))) + for idx in sorted(rel_pad_del, reverse=True): + del hack_map[idx] + + if DEBUG: + print('hack_map = {!r}'.format(hack_map)) + + # I believe it is the case that the deleted indexes will only be + # able to cause a shift in the abs_tail_stop, the abs_tail_start, + # abs_head_stop, and abs_head_start should never "conflict" with + # the deleted indexes (I think). + + # num_del_after_tail_start = sum(abs_tail_start <= i <= seq_stop for i in seq_del) + # print('num_del_after_tail_start = {!r}'.format(num_del_after_tail_start)) + # num_del_before_tail_start = sum(0 <= i <= rel_tail_stop for i in rel_pad_del) + + abs_head_start = hack_map[rel_head_start] + seq_start + abs_head_stop = hack_map[rel_head_stop] + seq_start + + abs_tail_start = hack_map[rel_tail_start] + seq_start + abs_tail_stop = hack_map[rel_tail_stop] + seq_start + + if DEBUG > 1: + print('abs_head_start = {!r}'.format(abs_head_start)) + print('abs_head_stop = {!r}'.format(abs_head_stop)) + + print('abs_tail_start = {!r}'.format(abs_tail_start)) + print('abs_tail_stop = {!r}'.format(abs_tail_stop)) + + head_sl = slice(rel_head_start, rel_head_stop) + tail_sl = slice(rel_tail_start, rel_tail_stop) + + head = seq[head_sl] + tail = seq[tail_sl] + head_tail = head + tail + + head_del = seq_del + tail_del = seq_del + + if abs_head_stop == abs_head_start: + # case where tail is empty (which head_tail doesnt matter + # anyway but this is just a POC + abs_head_tail_start = abs_tail_start + else: + abs_head_tail_start = abs_head_start + + if abs_tail_stop == abs_tail_start: + # case where tail is empty (which head_tail doesnt matter + # anyway but this is just a POC + abs_head_tail_stop = abs_head_stop + else: + abs_head_tail_stop = abs_tail_stop + + abs_del_start = seq_start + rel_start + abs_del_stop = seq_start + rel_stop + + # head_tail_del = [abs_del_start, abs_del_stop] + seq_del + assert abs_del_start < abs_head_tail_start + if abs_del_stop < abs_head_tail_stop: + head_tail_del = [abs_del_stop] + seq_del + else: + head_tail_del = seq_del + + # seq[head_sl] + seq[tail_sl] + + # pop_open, pop_close, head, tail, head_tail = balanced_decomp2(seq, open_to_close) + # node = open_to_node[pop_open[0]] + all_decomp[seq] = (seq_start, seq_stop, seq_del) + # (node, pop_open, pop_close, head, tail, head_tail) + + if abs_head_stop > len(full_seq): + raise AssertionError + if abs_tail_stop > len(full_seq): + raise AssertionError + if abs_head_tail_stop > len(full_seq): + raise AssertionError + + if head: + if DEBUG: + print('head = {!r}'.format(head)) + head_del = [i for i in head_del if abs_head_start <= i < abs_head_stop] + stack.append(('h', head, abs_head_start, abs_head_stop, head_del)) + if tail: + if DEBUG: + print('tail = {!r}'.format(tail)) + tail_del = [i for i in tail_del if abs_tail_start <= i < abs_tail_stop] + stack.append(('t', tail, abs_tail_start, abs_tail_stop, tail_del)) + if tail and head: + if DEBUG: + print('head_tail = {!r}'.format(head_tail)) + print('head_tail_del = {!r}'.format(head_tail_del)) + head_tail_del = [i for i in head_tail_del if abs_head_tail_start <= i < abs_head_tail_stop] + stack.append(('ht', head_tail, abs_head_tail_start, abs_head_tail_stop, head_tail_del)) + if DEBUG: + assert seq == recon + + return all_decomp diff --git a/netharn/initializers/_nx_ext/balanced_sequence.py b/netharn/initializers/_nx_ext/balanced_sequence.py new 
file mode 100644 index 0000000000000000000000000000000000000000..576bdda33dd86601ee6fd5ebf9593ca83e16ea05 --- /dev/null +++ b/netharn/initializers/_nx_ext/balanced_sequence.py @@ -0,0 +1,1146 @@ +""" +Balanced sequences are used via reduction to solve the maximum common subtree +embedding problem. +""" +import operator + + +def longest_common_balanced_sequence( + seq1, seq2, open_to_close, open_to_node=None, + node_affinity='auto', impl='iter-prehash2'): + """ + Finds the longest common balanced sequence between two sequences + + Parameters + ---------- + seq1, seq2: Iterable + two input balanced sequences + + open_to_close : Dict + a mapping from opening to closing tokens in the balanced sequence + + open_to_node : Dict | None + a dictionary that maps a sequence token to a token corresponding to an + original problem (e.g. a tree node), if unspecified an identity mapping + is assumed. FIXME: see outstanding issues. + WILL LIKELY CHANGE IN THE FUTURE + + node_affinity : None | str | callable + Function for to determine if two nodes can be matched. The return is + interpreted as a weight that is used to break ties. If None then any + node can match any other node and only the topology is important. + The default is "eq", which is the same as ``operator.eq``. + + impl : str + Determines the backend implementation. There are currently 8 different + backend implementations: + + recurse, iter, iter-prehash, iter-prehash2, iter-alt1, iter-alt2, + iter-alt2-cython, and iter-prehash2-cython. + + Example + ------- + >>> # extremely simple case + >>> seq1 = '[][[]][]' + >>> seq2 = '[[]][[]]' + >>> open_to_close = {'[': ']'} + >>> best, value = longest_common_balanced_sequence(seq1, seq2, open_to_close) + >>> subseq1, subseq2 = best + >>> print('subseq1 = {!r}'.format(subseq1)) + subseq1 = '[][[]]' + + >>> # 1-label case from the paper (see Example 5) + >>> # https://pdfs.semanticscholar.org/0b6e/061af02353f7d9b887f9a378be70be64d165.pdf + >>> seq1 = '0010010010111100001011011011' + >>> seq2 = '001000101101110001000100101110111011' + >>> open_to_close = {'0': '1'} + >>> best, value = longest_common_balanced_sequence(seq1, seq2, open_to_close) + >>> subseq1, subseq2 = best + >>> print('subseq1 = {!r}'.format(subseq1)) + >>> assert value == 13 + subseq1 = '00100101011100001011011011' + + >>> # 3-label case + >>> seq1 = '{({})([[]([]){(()(({()[]({}{})}))){}}])}' + >>> seq2 = '{[({{}}{{[][{}]}(()[(({()})){[]()}])})]}' + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> best, value = longest_common_balanced_sequence(seq1, seq2, open_to_close) + >>> subseq1, subseq2 = best + >>> print('subseq1 = {!r}'.format(subseq1)) + >>> assert value == 10 + subseq1 = '{{}[][]()(({()})){}}' + """ + if node_affinity == 'auto' or node_affinity == 'eq': + node_affinity = operator.eq + if node_affinity is None: + def _matchany(a, b): + return True + node_affinity = _matchany + if open_to_node is None: + open_to_node = IdentityDict() + full_seq1 = seq1 + full_seq2 = seq2 + if impl == 'auto': + if _cython_lcs_backend(): + impl = 'iter-alt2-cython' + else: + impl = 'iter-alt2' + + if impl == 'recurse': + _memo = {} + _seq_memo = {} + best, value = _lcs_recurse( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node, + _memo, _seq_memo) + elif impl == 'iter': + best, value = _lcs_iter_simple( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + elif impl == 'iter-prehash': + best, value = _lcs_iter_prehash( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + elif impl == 
'iter-prehash2': + best, value = _lcs_iter_prehash2( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + elif impl == 'iter-alt1': + best, value = _lcs_iter_simple_alt1( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + elif impl == 'iter-alt2': + best, value = _lcs_iter_simple_alt2( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + elif impl == 'iter-alt2-cython': + balanced_sequence_cython = _cython_lcs_backend(error='raise') + best, value = balanced_sequence_cython._lcs_iter_simple_alt2_cython( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + elif impl == 'iter-prehash2-cython': + balanced_sequence_cython = _cython_lcs_backend(error='raise') + best, value = balanced_sequence_cython._lcs_iter_prehash2_cython( + full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + else: + raise KeyError(impl) + return best, value + + +def available_impls_longest_common_balanced_sequence(): + """ + Returns all available implementations for + :func:`longest_common_balanced_sequence`. + """ + from netharn.initializers._nx_ext import balanced_sequence + impls = [] + if balanced_sequence._cython_lcs_backend(): + impls += [ + 'iter-alt2-cython', + 'iter-prehash2-cython', + ] + + # Pure python backends + impls += [ + 'iter-prehash2', + 'iter-alt2', + 'iter-alt1', + 'iter-prehash', + 'iter', + 'recurse', + ] + return impls + + +def _cython_lcs_backend(error='ignore'): + """ + Returns the cython backend if available, otherwise None + """ + try: + from netharn.initializers._nx_ext import balanced_sequence_cython + except Exception: + if error == 'ignore': + return None + elif error == 'raise': + raise + else: + raise KeyError(error) + else: + return balanced_sequence_cython + + +def _lcs_iter_simple_alt2(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node): + """ + Depth first stack trajectory and replace try except statements with ifs + + This is the current best pure-python algorithm candidate + + >>> full_seq1 = '{({})([[]([]){(()(({()[]({}{})}))){}}])}' + >>> full_seq2 = '{[({{}}{{[][{}]}(()[(({()})){[]()}])})]}' + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> full_seq1 = '[][[]][]' + >>> full_seq2 = '[[]][[]]' + >>> open_to_close = {'[': ']'} + >>> import operator as op + >>> node_affinity = op.eq + >>> open_to_node = IdentityDict() + >>> res = _lcs_iter_simple_alt2(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node) + >>> val, embeddings = res + """ + all_decomp1 = generate_all_decomp(full_seq1, open_to_close, open_to_node) + all_decomp2 = generate_all_decomp(full_seq2, open_to_close, open_to_node) + + key0 = (full_seq1, full_seq2) + frame0 = key0 + stack = [frame0] + + # Memoize mapping (seq1, seq2) -> best size, embeddings, deleted edges + _results = {} + + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + best = (empty1, empty2) + base_result = (0, best) + for seq1 in all_decomp1.keys(): + key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result 
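+
+    # The entries above are the base cases: the LCS of any decomposition with
+    # an empty sequence is empty and has value 0, so frames involving empty
+    # sequences are never pushed onto the stack below.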
+ + del frame0 + del empty1 + del empty2 + del best + del base_result + + while stack: + key = stack[-1] + if key not in _results: + seq1, seq2 = key + + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + + # Case 2: The current edge in sequence1 is deleted + try_key = (head_tail1, seq2) + if try_key in _results: + cand1 = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 3: The current edge in sequence2 is deleted + try_key = (seq1, head_tail2) + if try_key in _results: + cand2 = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try_key = (head1, head2) + if try_key in _results: + pval_h, new_heads = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + try_key = (tail1, tail2) + if try_key in _results: + pval_t, new_tails = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + stack.pop() + + val, best = _results[key0] + found = (best, val) + return found + + +def _lcs_iter_prehash2(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node): + """ + Version of the lcs iterative algorithm where we precompute hash values + + See :func:`longest_common_balanced_sequence` for parameter details. + """ + + all_decomp1 = generate_all_decomp_prehash(full_seq1, open_to_close, open_to_node) + all_decomp2 = generate_all_decomp_prehash(full_seq2, open_to_close, open_to_node) + + key_decomp1 = {} + key_decomp2 = {} + _results = {} + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + empty1_key = hash(empty1) + empty2_key = hash(empty2) + best = (empty1, empty2) + base_result = (0, best) + for seq1, info1 in all_decomp1.items(): + seq1_key = hash(seq1) + head1_key, tail1_key, head_tail1_key = all_decomp1[seq1][5:8] + _results[(seq1_key, empty2_key)] = base_result + _results[(head1_key, empty2_key)] = base_result + _results[(tail1_key, empty2_key)] = base_result + _results[(head_tail1_key, empty2_key)] = base_result + key_decomp1[seq1_key] = info1 + + for seq2, info2 in all_decomp2.items(): + seq2_key = hash(seq2) + head2_key, tail2_key, head_tail2_key = all_decomp2[seq2][5:8] + _results[(empty1_key, seq2_key)] = base_result + _results[(empty1_key, head2_key)] = base_result + _results[(empty1_key, tail2_key)] = base_result + _results[(empty1_key, head_tail2_key)] = base_result + key_decomp2[seq2_key] = info2 + + full_seq1_key = hash(full_seq1) + full_seq2_key = hash(full_seq2) + key0 = (full_seq1_key, full_seq2_key) + frame0 = key0, full_seq1, full_seq2 + stack = [frame0] + missing_frames = [] + while stack: + frame = stack[-1] + key, seq1, seq2 = frame + seq1_key, seq2_key = key + if key not in _results: + missing_frames.clear() + + info1 = key_decomp1[seq1_key] + tok1, seq1, head1, tail1, head_tail1, head1_key, tail1_key, head_tail1_key, a1, b1 = info1 + + # if seq2_key not in key_decomp2: + info2 = key_decomp2[seq2_key] + tok2, seq2, head2, tail2, head_tail2, head2_key, tail2_key, 
head_tail2_key, a2, b2 = info2 + + affinity = node_affinity(tok1, tok2) + + # Case 2: The current edge in sequence1 is deleted + try_key = (head_tail1_key, seq2_key) + if try_key in _results: + cand1 = _results[try_key] + else: + miss_frame = try_key, head_tail1, seq2 + stack.append(miss_frame) + continue + + # Case 3: The current edge in sequence2 is deleted + try_key = (seq1_key, head_tail2_key) + if try_key in _results: + cand2 = _results[try_key] + else: + miss_frame = try_key, seq1, head_tail2 + stack.append(miss_frame) + continue + + # Case 1: The LCS involves this edge + if affinity: + try_key = (head1_key, head2_key) + if try_key in _results: + pval_h, new_heads = _results[try_key] + else: + miss_frame = try_key, head1, head2 + stack.append(miss_frame) + continue + + try_key = (tail1_key, tail2_key) + if try_key in _results: + pval_t, new_tails = _results[try_key] + else: + miss_frame = try_key, tail1, tail2 + stack.append(miss_frame) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + stack.pop() + + # The stack pop is our solution + (val, best) = _results[key0] + found = (best, val) + return found + + +def _lcs_recurse(seq1, seq2, open_to_close, node_affinity, open_to_node, _memo, _seq_memo): + """ + Surprisingly, this recursive implementation is one of the faster + pure-python methods for certain input types. However, its major drawback is + that it can raise a RecurssionError if the inputs are too deep. + """ + if not seq1: + return (seq1, seq1), 0 + elif not seq2: + return (seq2, seq2), 0 + else: + key1 = hash(seq1) # using hash(seq) is faster than seq itself + key2 = hash(seq2) + key = hash((key1, key2)) + if key in _memo: + return _memo[key] + + if key1 in _seq_memo: + a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] + else: + a1, b1, head1, tail1, head1_tail1 = balanced_decomp_unsafe(seq1, open_to_close) + _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 + + if key2 in _seq_memo: + a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] + else: + a2, b2, head2, tail2, head2_tail2 = balanced_decomp_unsafe(seq2, open_to_close) + _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 + + # Case 2: The current edge in sequence1 is deleted + best, val = _lcs_recurse(head1_tail1, seq2, open_to_close, node_affinity, open_to_node, _memo, _seq_memo) + + # Case 3: The current edge in sequence2 is deleted + cand, val_alt = _lcs_recurse(seq1, head2_tail2, open_to_close, node_affinity, open_to_node, _memo, _seq_memo) + if val_alt > val: + best = cand + val = val_alt + + # Case 1: The LCS involves this edge + t1 = open_to_node[a1[0]] + t2 = open_to_node[a2[0]] + affinity = node_affinity(t1, t2) + if affinity: + new_heads, pval_h = _lcs_recurse(head1, head2, open_to_close, node_affinity, open_to_node, _memo, _seq_memo) + new_tails, pval_t = _lcs_recurse(tail1, tail2, open_to_close, node_affinity, open_to_node, _memo, _seq_memo) + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + cand = (subseq1, subseq2) + val_alt = pval_h + pval_t + affinity + if val_alt > val: + best = cand + val = val_alt + + found = (best, val) + _memo[key] = found + return found + + +def 
_lcs_iter_simple(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node): + """ + Converts _lcs_recursive to an iterative algorithm using a fairly + straightforward method that effectivly simulates callstacks. + Uses a breadth-first trajectory and try-except to catch missing + memoized results (which seems to be slightly slower than if statements). + """ + all_decomp1 = generate_all_decomp(full_seq1, open_to_close, open_to_node) + all_decomp2 = generate_all_decomp(full_seq2, open_to_close, open_to_node) + + args0 = (full_seq1, full_seq2) + frame0 = args0 + stack = [frame0] + + _results = {} + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + best = (empty1, empty2) + base_result = (0, best) + for seq1 in all_decomp1.keys(): + key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result + + del args0 + del frame0 + del empty1 + del empty2 + del best + del base_result + + missing_frames = [] + while stack: + key = stack.pop() + if key not in _results: + seq1, seq2 = key + missing_frames.clear() + + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + + # Case 2: The current edge in sequence1 is deleted + try: + try_key = (head_tail1, seq2) + cand1 = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + # Case 3: The current edge in sequence2 is deleted + try: + try_key = (seq1, head_tail2) + cand2 = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try: + try_key = (head1, head2) + pval_h, new_heads = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + try: + try_key = (tail1, tail2) + pval_t, new_tails = _results[try_key] + except KeyError: + missing_frames.append(try_key) + + if not missing_frames: + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + if missing_frames: + # We did not solve this frame yet + stack.append(key) + stack.extend(missing_frames) + # stack.extend(missing_frames[::-1]) + else: + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + + val, best = _results[key] + found = (best, val) + return found + + +def _lcs_iter_simple_alt1(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node): + """ + Depth first stack trajectory + """ + all_decomp1 = generate_all_decomp(full_seq1, open_to_close, open_to_node) + all_decomp2 = generate_all_decomp(full_seq2, open_to_close, open_to_node) + + args0 = (full_seq1, full_seq2) + frame0 = args0 + stack = [frame0] + + _results = {} + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + best = (empty1, empty2) + base_result = (0, best) + for seq1 in all_decomp1.keys(): 
+ key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result + + del args0 + del frame0 + del empty1 + del empty2 + del best + del base_result + + while stack: + key = stack.pop() + if key not in _results: + seq1, seq2 = key + + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + + # Case 2: The current edge in sequence1 is deleted + try: + try_key = (head_tail1, seq2) + cand1 = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + # Case 3: The current edge in sequence2 is deleted + try: + try_key = (seq1, head_tail2) + cand2 = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try: + try_key = (head1, head2) + pval_h, new_heads = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + try: + try_key = (tail1, tail2) + pval_t, new_tails = _results[try_key] + except KeyError: + stack.append(key) + stack.append(try_key) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + + val, best = _results[key] + found = (best, val) + return found + + +def _lcs_iter_prehash(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node): + """ + Version of the lcs iterative algorithm where we precompute hash values. + Uses a breadth-first trajectory. 
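+
+    See :func:`longest_common_balanced_sequence` for parameter details.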
+ """ + all_decomp1 = generate_all_decomp_prehash(full_seq1, open_to_close, open_to_node) + all_decomp2 = generate_all_decomp_prehash(full_seq2, open_to_close, open_to_node) + + key_decomp1 = {} + key_decomp2 = {} + _results = {} + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + empty1_key = hash(empty1) + empty2_key = hash(empty2) + best = (empty1, empty2) + base_result = (0, best) + for seq1, info1 in all_decomp1.items(): + seq1_key = hash(seq1) + head1_key, tail1_key, head_tail1_key = all_decomp1[seq1][5:8] + _results[(seq1_key, empty2_key)] = base_result + _results[(head1_key, empty2_key)] = base_result + _results[(tail1_key, empty2_key)] = base_result + _results[(head_tail1_key, empty2_key)] = base_result + key_decomp1[seq1_key] = info1 + + for seq2, info2 in all_decomp2.items(): + seq2_key = hash(seq2) + head2_key, tail2_key, head_tail2_key = all_decomp2[seq2][5:8] + _results[(empty1_key, seq2_key)] = base_result + _results[(empty1_key, head2_key)] = base_result + _results[(empty1_key, tail2_key)] = base_result + _results[(empty1_key, head_tail2_key)] = base_result + key_decomp2[seq2_key] = info2 + + full_seq1_key = hash(full_seq1) + full_seq2_key = hash(full_seq2) + key0 = (full_seq1_key, full_seq2_key) + frame0 = key0, full_seq1, full_seq2 + stack = [frame0] + missing_frames = [] + while stack: + frame = stack.pop() + key, seq1, seq2 = frame + seq1_key, seq2_key = key + if key not in _results: + missing_frames.clear() + + try: + info1 = key_decomp1[seq1_key] + except KeyError: + info1 = balanced_decomp_prehash(seq1, open_to_close) + key_decomp1[seq1_key] = info1 + tok1, seq1, head1, tail1, head_tail1, head1_key, tail1_key, head_tail1_key, a1, b1 = info1 + + try: + info2 = key_decomp2[seq2_key] + except KeyError: + info2 = balanced_decomp_prehash(seq2, open_to_close) + key_decomp2[seq2_key] = info2 + tok2, seq2, head2, tail2, head_tail2, head2_key, tail2_key, head_tail2_key, a2, b2 = info2 + + affinity = node_affinity(tok1, tok2) + + # Case 2: The current edge in sequence1 is deleted + try: + try_key = (head_tail1_key, seq2_key) + cand1 = _results[try_key] + except KeyError: + miss_frame = try_key, head_tail1, seq2 + missing_frames.append(miss_frame) + + # Case 3: The current edge in sequence2 is deleted + try: + try_key = (seq1_key, head_tail2_key) + cand2 = _results[try_key] + except KeyError: + miss_frame = try_key, seq1, head_tail2 + missing_frames.append(miss_frame) + + # Case 1: The LCS involves this edge + if affinity: + try: + try_key = (head1_key, head2_key) + pval_h, new_heads = _results[try_key] + except KeyError: + miss_frame = try_key, head1, head2 + missing_frames.append(miss_frame) + + try: + try_key = (tail1_key, tail2_key) + pval_t, new_tails = _results[try_key] + except KeyError: + miss_frame = try_key, tail1, tail2 + missing_frames.append(miss_frame) + + if not missing_frames: + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + if missing_frames: + # We did not solve this frame yet + stack.append(frame) + stack.extend(missing_frames[::-1]) + else: + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + + # The stack pop is our solution + (val, best) = _results[key] + found = (best, val) + return found + + +class UnbalancedException(Exception): + """ + 
Denotes that a sequence was unbalanced + """ + pass + + +class IdentityDict: + """ + Used when ``open_to_node`` is unspecified + """ + def __getitem__(self, key): + return key + + +def generate_all_decomp(seq, open_to_close, open_to_node=None): + """ + Generates all decompositions of a single balanced sequence by + recursive decomposition of the head, tail, and head|tail. + + Parameters + ---------- + seq : Tuple | str + a tuple of hashable items or a string where each character is an item + + open_to_close : Dict + a dictionary that maps opening tokens to closing tokens in the balanced + sequence problem. + + open_to_node : Dict + a dictionary that maps a sequence token to a token corresponding to an + original problem (e.g. a tree node) + + Returns + ------- + Dict : mapping from a sub-sequence to its decomposition + + Notes + ----- + In the paper: See Definition 2, 4, Lemma, 1, 2, 3, 4. + + Example + ------- + >>> # Example 2 in the paper (one from each column) + >>> seq = '00100100101111' + >>> open_to_close = {'0': '1'} + >>> all_decomp = generate_all_decomp(seq, open_to_close) + >>> assert len(all_decomp) == len(seq) // 2 + >>> import pprint + >>> pprint.pprint(all_decomp) + {'00100100101111': ('0', '0', '1', '010010010111', '', '010010010111'), + '0010010111': ('0', '0', '1', '01001011', '', '01001011'), + '001011': ('0', '0', '1', '0101', '', '0101'), + '01': ('0', '0', '1', '', '', ''), + '010010010111': ('0', '0', '1', '', '0010010111', '0010010111'), + '01001011': ('0', '0', '1', '', '001011', '001011'), + '0101': ('0', '0', '1', '', '01', '01')} + + Example + ------- + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> seq = '({[[]]})[[][]]{{}}' + >>> all_decomp = generate_all_decomp(seq, open_to_close) + >>> node, *decomp = all_decomp[seq] + >>> pop_open, pop_close, head, tail, head_tail = decomp + >>> print('node = {!r}'.format(node)) + >>> print('pop_open = {!r}'.format(pop_open)) + >>> print('pop_close = {!r}'.format(pop_close)) + >>> print('head = {!r}'.format(head)) + >>> print('tail = {!r}'.format(tail)) + >>> print('head_tail = {!r}'.format(head_tail)) + node = '(' + pop_open = '(' + pop_close = ')' + head = '{[[]]}' + tail = '[[][]]{{}}' + head_tail = '{[[]]}[[][]]{{}}' + >>> decomp_alt = balanced_decomp(seq, open_to_close) + >>> assert decomp_alt == tuple(decomp) + + Example + ------- + >>> from netharn.initializers._nx_ext.demodata import random_balanced_sequence + >>> seq, open_to_close = random_balanced_sequence(10) + >>> all_decomp = generate_all_decomp(seq, open_to_close) + """ + if open_to_node is None: + open_to_node = IdentityDict() + all_decomp = {} + stack = [seq] + while stack: + seq = stack.pop() + if seq not in all_decomp and seq: + pop_open, pop_close, head, tail, head_tail = balanced_decomp(seq, open_to_close) + node = open_to_node[pop_open[0]] + all_decomp[seq] = (node, pop_open, pop_close, head, tail, head_tail) + if head: + if tail: + stack.append(head_tail) + stack.append(tail) + stack.append(head) + elif tail: + stack.append(tail) + return all_decomp + + +def balanced_decomp(sequence, open_to_close): + """ + Generates a decomposition of a balanced sequence. + + Parameters + ---------- + sequence : str + balanced sequence to be decomposed + + open_to_close: dict + a dictionary that maps opening tokens to closing tokens in the balanced + sequence problem. + + Returns + ------- + : tuple[T, T, T, T, T] + where ``T = type(sequence)`` + Contents of this tuple are: + + 0. a1 - a sequence of len(1) containing the current opening token + 1. 
b1 - a sequence of len(1) containing the current closing token + 2. head - head of the sequence + 3. tail - tail of the sequence + 4. head_tail - the concatanted head and tail + + Example + ------- + >>> # Example 3 from the paper + >>> sequence = '001000101101110001000100101110111011' + >>> open_to_close = {'0': '1'} + >>> a1, b1, head, tail, head_tail = balanced_decomp(sequence, open_to_close) + >>> print('head = {!r}'.format(head)) + >>> print('tail = {!r}'.format(tail)) + head = '010001011011' + tail = '0001000100101110111011' + + Example + ------- + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1, 0, 1] + >>> a1, b1, head, tail, head_tail = balanced_decomp(sequence, open_to_close) + >>> print('a1 = {!r}'.format(a1)) + >>> print('b1 = {!r}'.format(b1)) + >>> print('head = {!r}'.format(head)) + >>> print('tail = {!r}'.format(tail)) + >>> print('head_tail = {!r}'.format(head_tail)) + a1 = [0] + b1 = [1] + head = [0, 0, 1, 1] + tail = [0, 1] + head_tail = [0, 0, 1, 1, 0, 1] + >>> a2, b2, tail1, tail2, head_tail2 = balanced_decomp(tail, open_to_close) + + Example + ------- + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> sequence = '({[[]]})[[][]]' + >>> a1, b1, head, tail, head_tail = balanced_decomp(sequence, open_to_close) + >>> print('a1 = {!r}'.format(a1)) + >>> print('b1 = {!r}'.format(b1)) + >>> print('head = {!r}'.format(head)) + >>> print('tail = {!r}'.format(tail)) + >>> print('head_tail = {!r}'.format(head_tail)) + a1 = '(' + b1 = ')' + head = '{[[]]}' + tail = '[[][]]' + head_tail = '{[[]]}[[][]]' + >>> a2, b2, tail1, tail2, head_tail2 = balanced_decomp(tail, open_to_close) + >>> print('a2 = {!r}'.format(a2)) + >>> print('b2 = {!r}'.format(b2)) + >>> print('tail1 = {!r}'.format(tail1)) + >>> print('tail2 = {!r}'.format(tail2)) + >>> print('head_tail2 = {!r}'.format(head_tail2)) + a2 = '[' + b2 = ']' + tail1 = '[][]' + tail2 = '' + head_tail2 = '[][]' + """ + gen = generate_balance(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if tok_curr is None: + break + elif bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + head_tail = head + tail + return pop_open, pop_close, head, tail, head_tail + + +def generate_balance(sequence, open_to_close): + """ + Iterates through a balanced sequence and reports if the sequence-so-far + is balanced at that position or not. 
+ + Parameters + ---------- + sequence: List[Tuple] | str: + an input balanced sequence + + open_to_close : Dict + a mapping from opening to closing tokens in the balanced sequence + + Raises + ------ + UnbalancedException - if the input sequence is not balanced + + Yields + ------ + Tuple[bool, T]: + boolean indicating if the sequence is balanced at this index, + and the current token + + Example + ------- + >>> open_to_close = {0: 1} + >>> sequence = [0, 0, 0, 1, 1, 1] + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + + Example + ------- + >>> from netharn.initializers._nx_ext.demodata import random_balanced_sequence + >>> sequence, open_to_close = random_balanced_sequence(4) + >>> print('sequence = {!r}'.format(sequence)) + >>> gen = list(generate_balance(sequence, open_to_close)) + >>> for flag, token in gen: + >>> print('flag={:d}, token={}'.format(flag, token)) + """ + stack = [] + # Traversing the Expression + for token in sequence: + + if token in open_to_close: + # Push opening elements onto the stack + stack.append(token) + else: + # Check that closing elements + if not stack: + raise UnbalancedException + prev_open = stack.pop() + want_close = open_to_close[prev_open] + + if token != want_close: + raise UnbalancedException + + # If the stack is empty the sequence is currently balanced + currently_balanced = not bool(stack) + yield currently_balanced, token + + if stack: + raise UnbalancedException + + +def generate_all_decomp_prehash(seq, open_to_close, open_to_node): + """ + Like :func:`generate_all_decomp` but additionally returns the + precomputed hashes of the sequences. + """ + all_decomp = {} + stack = [seq] + while stack: + seq = stack.pop() + if seq: + # key = hash(seq) + key = seq + if key not in all_decomp: + info = balanced_decomp_prehash(seq, open_to_close, open_to_node) + head, tail, head_tail = info[2:5] + all_decomp[key] = info + stack.append(head_tail) + stack.append(head) + stack.append(tail) + return all_decomp + + +def balanced_decomp_prehash(seq, open_to_close, open_to_node): + """ + Like :func:`balanced_decomp` but additionally returns the + precomputed hashes of the sequences. + """ + pop_open, pop_close, head, tail, head_tail = balanced_decomp_unsafe(seq, open_to_close) + head_key = hash(head) + tail_key = hash(tail) + head_tail_key = hash(head_tail) + node = open_to_node[pop_open[0]] + a = pop_open + b = pop_close + info = (node, seq, head, tail, head_tail, head_key, tail_key, head_tail_key, a, b) + return info + + +def balanced_decomp_unsafe(sequence, open_to_close): + """ + Same as :func:`balanced_decomp` but assumes that ``sequence`` is valid + balanced sequence in order to execute faster. + """ + gen = generate_balance_unsafe(sequence, open_to_close) + + bal_curr, tok_curr = next(gen) + pop_open = sequence[0:1] + want_close = open_to_close[tok_curr] + + head_stop = 1 + for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): + if bal_curr and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + head_tail = head + tail + return pop_open, pop_close, head, tail, head_tail + + +def generate_balance_unsafe(sequence, open_to_close): + """ + Same as :func:`generate_balance` but assumes that ``sequence`` is valid + balanced sequence in order to execute faster. 
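+
+    A minimal usage sketch: the yielded flag is True exactly where the prefix
+    consumed so far is balanced.
+
+    Example
+    -------
+    >>> open_to_close = {'(': ')'}
+    >>> list(generate_balance_unsafe('(())', open_to_close))
+    [(False, '('), (False, '('), (False, ')'), (True, ')')]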
+    """
+    stacklen = 0
+    for token in sequence:
+        if token in open_to_close:
+            stacklen += 1
+        else:
+            stacklen -= 1
+        yield stacklen == 0, token
diff --git a/netharn/initializers/_nx_ext/balanced_sequence_cython.pyx b/netharn/initializers/_nx_ext/balanced_sequence_cython.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..b9f0df5279b345aeb0fab4b00671363ff99d7e93
--- /dev/null
+++ b/netharn/initializers/_nx_ext/balanced_sequence_cython.pyx
@@ -0,0 +1,344 @@
+# distutils: language = c++
+"""
+This module re-implements functions in :mod:`balanced_sequence` in cython
+and obtains 40-50x speedups in common circumstances. There are likely more
+speed improvements that could be made.
+
+CommandLine
+-----------
+# Explicitly build this cython module (must be run in the networkx repo root)
+cythonize -a -i networkx/algorithms/isomorphism/_embedding/balanced_sequence_cython.pyx
+
+
+Examples
+--------
+>>> from networkx.algorithms.isomorphism._embedding.balanced_sequence_cython import _lcs_iter_prehash2_cython
+>>> from networkx.algorithms.isomorphism._embedding.balanced_sequence_cython import _lcs_iter_simple_alt2_cython
+>>> from networkx.algorithms.isomorphism._embedding.balanced_sequence_cython import IdentityDict
+>>> from networkx.algorithms.isomorphism._embedding.demodata import random_balanced_sequence
+>>> seq1, open_to_close1 = random_balanced_sequence(300, mode='paren')
+>>> seq2, open_to_close2 = random_balanced_sequence(300, mode='paren')
+>>> open_to_close = {**open_to_close1, **open_to_close2}
+>>> full_seq1 = seq1
+>>> full_seq2 = seq2
+>>> import operator
+>>> node_affinity = operator.eq
+>>> open_to_node = IdentityDict()
+>>> best1, value1 = _lcs_iter_prehash2_cython(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node)
+>>> best2, value2 = _lcs_iter_simple_alt2_cython(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node)
+>>> assert value1 == value2
+"""
+
+
+def _lcs_iter_prehash2_cython(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node):
+    """
+    Version of the lcs iterative algorithm where we precompute hash values.
+
+    This is the current fastest implementation candidate for the LCS problem,
+    but note that the alternative version is faster in some cases.
+ """ + cdef dict all_decomp1 = generate_all_decomp_prehash_cython(full_seq1, open_to_close, open_to_node) + cdef dict all_decomp2 = generate_all_decomp_prehash_cython(full_seq2, open_to_close, open_to_node) + cdef dict key_decomp1 = {} + cdef dict key_decomp2 = {} + + cdef dict _results = {} + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + cdef Py_hash_t empty1_key = hash(empty1) + cdef Py_hash_t empty2_key = hash(empty2) + cdef tuple best = (empty1, empty2) + + cdef tuple info1, info2 + cdef tuple try_key, key + cdef Py_hash_t seq1_key, seq2_key + cdef Py_hash_t head1_key, tail1_key, head_tail1_key + cdef Py_hash_t head2_key, tail2_key, head_tail2_key + cdef tuple frame + cdef tuple miss_frame + + base_result = (0, best) + for seq1, info1 in all_decomp1.items(): + seq1_key = hash(seq1) + head1_key, tail1_key, head_tail1_key = all_decomp1[seq1][5:8] + _results[(seq1_key, empty2_key)] = base_result + _results[(head1_key, empty2_key)] = base_result + _results[(tail1_key, empty2_key)] = base_result + _results[(head_tail1_key, empty2_key)] = base_result + key_decomp1[seq1_key] = info1 + + for seq2, info2 in all_decomp2.items(): + seq2_key = hash(seq2) + head2_key, tail2_key, head_tail2_key = all_decomp2[seq2][5:8] + _results[(empty1_key, seq2_key)] = base_result + _results[(empty1_key, head2_key)] = base_result + _results[(empty1_key, tail2_key)] = base_result + _results[(empty1_key, head_tail2_key)] = base_result + key_decomp2[seq2_key] = info2 + + cdef Py_hash_t full_seq1_key = hash(full_seq1) + cdef Py_hash_t full_seq2_key = hash(full_seq2) + + cdef tuple key0 = (full_seq1_key, full_seq2_key) + cdef tuple frame0 = (key0, full_seq1, full_seq2) + cdef list stack = [frame0] + + while stack: + frame = stack[-1] + key, seq1, seq2 = frame + seq1_key, seq2_key = key + if key not in _results: + info1 = key_decomp1[seq1_key] + tok1, seq1, head1, tail1, head_tail1, head1_key, tail1_key, head_tail1_key, a1, b1 = info1 + + info2 = key_decomp2[seq2_key] + tok2, seq2, head2, tail2, head_tail2, head2_key, tail2_key, head_tail2_key, a2, b2 = info2 + + affinity = node_affinity(tok1, tok2) + + # Case 2: The current edge in sequence1 is deleted + try_key = (head_tail1_key, seq2_key) + if try_key in _results: + cand1 = _results[try_key] + else: + miss_frame = try_key, head_tail1, seq2 + stack.append(miss_frame) + continue + + # Case 3: The current edge in sequence2 is deleted + try_key = (seq1_key, head_tail2_key) + if try_key in _results: + cand2 = _results[try_key] + else: + miss_frame = try_key, seq1, head_tail2 + stack.append(miss_frame) + continue + + # Case 1: The LCS involves this edge + if affinity: + try_key = (head1_key, head2_key) + if try_key in _results: + pval_h, new_heads = _results[try_key] + else: + miss_frame = try_key, head1, head2 + stack.append(miss_frame) + continue + + try_key = (tail1_key, tail2_key) + if try_key in _results: + pval_t, new_tails = _results[try_key] + else: + miss_frame = try_key, tail1, tail2 + stack.append(miss_frame) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + stack.pop() + + # The stack pop is our solution + (val, best) = _results[key0] + found = (best, val) + return found + + + 
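+# NOTE: callers can reach these backends through the pure-python dispatcher in
+# balanced_sequence.py rather than importing this module directly. A minimal
+# sketch (assuming the compiled extension imports cleanly):
+#
+#     from netharn.initializers._nx_ext.balanced_sequence import (
+#         longest_common_balanced_sequence)
+#     best, value = longest_common_balanced_sequence(
+#         seq1, seq2, open_to_close, impl='iter-prehash2-cython')
+#
+# If the extension is unavailable, _cython_lcs_backend returns None and
+# impl='auto' falls back to the pure-python implementations.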
+ +def _lcs_iter_simple_alt2_cython(full_seq1, full_seq2, open_to_close, node_affinity, open_to_node): + """ + Depth first stack trajectory and replace try except statements with ifs + """ + if open_to_node is None: + open_to_node = IdentityDict() + all_decomp1 = generate_all_decomp_cython(full_seq1, open_to_close, open_to_node) + all_decomp2 = generate_all_decomp_cython(full_seq2, open_to_close, open_to_node) + + key0 = (full_seq1, full_seq2) + frame0 = key0 + stack = [frame0] + + _results = {} + # Populate base cases + empty1 = type(next(iter(all_decomp1.keys())))() + empty2 = type(next(iter(all_decomp2.keys())))() + best = (empty1, empty2) + base_result = (0, best) + for seq1 in all_decomp1.keys(): + key1 = seq1 + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] + _results[(seq1, empty2)] = base_result + _results[(head1, empty2)] = base_result + _results[(tail1, empty2)] = base_result + _results[(head_tail1, empty2)] = base_result + + for seq2 in all_decomp2.keys(): + key2 = seq2 + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] + _results[(empty1, seq2)] = base_result + _results[(empty1, head2)] = base_result + _results[(empty1, tail2)] = base_result + _results[(empty1, head_tail2)] = base_result + + while stack: + key = stack[-1] + if key not in _results: + seq1, seq2 = key + + t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] + t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] + + # Case 2: The current edge in sequence1 is deleted + try_key = (head_tail1, seq2) + if try_key in _results: + cand1 = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 3: The current edge in sequence2 is deleted + try_key = (seq1, head_tail2) + if try_key in _results: + cand2 = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + # Case 1: The LCS involves this edge + affinity = node_affinity(t1, t2) + if affinity: + try_key = (head1, head2) + if try_key in _results: + pval_h, new_heads = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + try_key = (tail1, tail2) + if try_key in _results: + pval_t, new_tails = _results[try_key] + else: + # stack.append(key) + stack.append(try_key) + continue + + new_head1, new_head2 = new_heads + new_tail1, new_tail2 = new_tails + + subseq1 = a1 + new_head1 + b1 + new_tail1 + subseq2 = a2 + new_head2 + b2 + new_tail2 + + res3 = (subseq1, subseq2) + val3 = pval_h + pval_t + affinity + cand3 = (val3, res3) + else: + cand3 = (-1, None) + + # We solved the frame + _results[key] = max(cand1, cand2, cand3) + stack.pop() + + val, best = _results[key0] + found = (best, val) + return found + + +cdef tuple balanced_decomp_unsafe_cython(sequence, dict open_to_close): + """ + Cython version of :func:`balanced_decomp_unsafe`. + """ + cdef int stacklen = 1 # always +1 in the first iteration + cdef int head_stop = 1 + + tok_curr = sequence[0] + want_close = open_to_close[tok_curr] + + # for tok_curr in sequence[1:]: + for head_stop in range(1, len(sequence)): + tok_curr = sequence[head_stop] + stacklen += 1 if tok_curr in open_to_close else -1 + if stacklen == 0 and tok_curr == want_close: + pop_close = sequence[head_stop:head_stop + 1] + break + + pop_open = sequence[0:1] + head = sequence[1:head_stop] + tail = sequence[head_stop + 1:] + head_tail = head + tail + return pop_open, pop_close, head, tail, head_tail + + +cdef generate_all_decomp_cython(seq, open_to_close, open_to_node=None): + """ + Cython version of :func:`generate_all_decomp`. 
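+
+    Maps every subsequence reachable from ``seq`` (by repeatedly taking the
+    head, tail, and head_tail decompositions) to the tuple
+    ``(node, pop_open, pop_close, head, tail, head_tail)``, where the last
+    five items come from :func:`balanced_decomp_unsafe_cython` and ``node``
+    is looked up via ``open_to_node``.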
+ """ + all_decomp = {} + stack = [seq] + while stack: + seq = stack.pop() + if seq not in all_decomp and seq: + pop_open, pop_close, head, tail, head_tail = balanced_decomp_unsafe_cython(seq, open_to_close) + node = open_to_node[pop_open[0]] + all_decomp[seq] = (node, pop_open, pop_close, head, tail, head_tail) + stack.append(head_tail) + stack.append(head) + stack.append(tail) + return all_decomp + + +cdef tuple balanced_decomp_prehash_cython(seq, dict open_to_close, open_to_node): + """ + Cython version of :func:`balanced_decomp_unsafe`. + """ + cdef tuple info + pop_open, pop_close, head, tail, head_tail = balanced_decomp_unsafe_cython(seq, open_to_close) + cdef Py_hash_t head_key = hash(head) + cdef Py_hash_t tail_key = hash(tail) + cdef Py_hash_t head_tail_key = hash(head_tail) + node = open_to_node[pop_open[0]] + a = pop_open + b = pop_close + info = (node, seq, head, tail, head_tail, head_key, tail_key, head_tail_key, a, b) + return info + + +cdef dict generate_all_decomp_prehash_cython(seq, dict open_to_close, open_to_node): + """ + Cython version of :func:`generate_all_decomp_prehash`. + """ + cdef dict all_decomp = {} + cdef list stack = [seq] + cdef tuple info + while stack: + seq = stack.pop() + if seq: + # key = hash(seq) + key = seq + if key not in all_decomp: + info = balanced_decomp_prehash_cython(seq, open_to_close, open_to_node) + head, tail, head_tail = info[2:5] + all_decomp[key] = info + stack.append(head_tail) + stack.append(head) + stack.append(tail) + return all_decomp + + +class IdentityDict: + """ Used when ``open_to_node`` is unspecified """ + def __getitem__(self, key): + return key diff --git a/netharn/initializers/_nx_ext/benchmarks.py b/netharn/initializers/_nx_ext/benchmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9e70a93366f63cc91f8ec2a8b082f8c9a230e4 --- /dev/null +++ b/netharn/initializers/_nx_ext/benchmarks.py @@ -0,0 +1,387 @@ +from netharn.initializers._nx_ext.path_embedding import ( # NOQA + maximum_common_path_embedding) +# from netharn.initializers._nx_ext.tree_embedding import ( # NOQA +# maximum_common_ordered_tree_embedding, tree_to_seq) +from netharn.initializers._nx_ext.demodata import random_paths +from netharn.initializers._nx_ext.demodata import random_ordered_tree # NOQA +import operator + + +def bench_maximum_common_path_embedding(): + """ + xdoctest -m netharn.initializers._nx_ext.benchmarks bench_maximum_common_path_embedding + """ + import itertools as it + import ubelt as ub + import timerit + from netharn.initializers._nx_ext import balanced_sequence + from netharn.initializers._nx_ext import path_embedding + + data_modes = [] + + # Define which implementations we are going to test + run_basis = { + 'mode': [ + 'chr', + # 'number' + # 'tuple', # by far the slowest + ], + 'impl': balanced_sequence.available_impls_longest_common_balanced_sequence(), + } + + # Define the properties of the random data we are going to test on + data_basis = { + 'size': [20, 50], + 'max_depth': [8, 16], + 'common': [8, 16], + 'prefix_depth1': [0, 4], + 'prefix_depth2': [0, 4], + # 'labels': [26 ** 1, 26 ** 8] + 'labels': [1, 26] + } + + # run_basis['impl'] = set(run_basis['impl']) & { + # 'iter-alt2-cython', + # 'iter-prehash2-cython', + # 'iter-prehash2', + # 'iter-alt2', + # # 'iter-alt1', + # # 'iter-prehash', + # # 'iter', + # # 'recurse' + # } + + # TODO: parametarize demo names + # BENCH_MODE = None + # BENCH_MODE = 'small' + # BENCH_MODE = 'small2' + # BENCH_MODE = 'recursion-error' + BENCH_MODE = 'medium' + # 
BENCH_MODE = 'large' + + if BENCH_MODE == 'small': + data_basis = { + 'size': [30], + 'max_depth': [8, 2], + 'common': [2, 8], + 'prefix_depth1': [0, 4], + 'prefix_depth2': [0], + 'labels': [4] + } + run_basis['impl'] = set(run_basis['impl']) & { + # 'iter-alt2-cython', + 'iter-prehash2-cython', + 'iter-prehash2', + # 'iter-alt2', + # 'iter', + # 'recurse', + } + run_basis['impl'] = ub.oset(balanced_sequence.available_impls_longest_common_balanced_sequence()) - { + 'recurse', + } + # runparam_to_time = { + # ('chr', 'iter-prehash2-cython'): {'mean': 0.062, 'max': 0.157}, + # ('chr', 'iter-prehash2') : {'mean': 0.071, 'max': 0.185}, + # } + + if BENCH_MODE == 'small2': + data_basis = { + 'size': [30], + 'max_depth': [8, 2], + 'common': [2, 8], + 'prefix_depth1': [0, 4], + 'prefix_depth2': [0], + 'labels': [4] + } + run_basis['impl'] = ub.oset(balanced_sequence.available_impls_longest_common_balanced_sequence()) - { + 'recurse', + } + run_basis['mode'] = ['number', 'chr'] + # runparam_to_time = { + # ('chr', 'iter-alt2-cython') : {'mean': 0.036, 'max': 0.094}, + # ('chr', 'iter-alt2') : {'mean': 0.049, 'max': 0.125}, + # ('chr', 'iter-alt1') : {'mean': 0.050, 'max': 0.129}, + # ('chr', 'iter-prehash2-cython') : {'mean': 0.057, 'max': 0.146}, + # ('number', 'iter-prehash2-cython'): {'mean': 0.057, 'max': 0.146}, + # ('chr', 'iter') : {'mean': 0.064, 'max': 0.167}, + # ('chr', 'iter-prehash2') : {'mean': 0.066, 'max': 0.170}, + # ('number', 'iter-prehash2') : {'mean': 0.067, 'max': 0.176}, + # ('chr', 'iter-prehash') : {'mean': 0.073, 'max': 0.187}, + # ('number', 'iter-prehash') : {'mean': 0.074, 'max': 0.196}, + # ('number', 'iter-alt1') : {'mean': 0.126, 'max': 0.344}, + # ('number', 'iter-alt2-cython') : {'mean': 0.133, 'max': 0.363}, + # ('number', 'iter') : {'mean': 0.140, 'max': 0.386}, + # ('number', 'iter-alt2') : {'mean': 0.149, 'max': 0.408}, + # } + + if BENCH_MODE == 'medium': + data_basis = { + 'size': [30, 40], + 'max_depth': [4, 8], + 'common': [8, 50], + 'prefix_depth1': [0, 4], + 'prefix_depth2': [2], + 'labels': [8, 1] + } + # Results + # runparam_to_time = { + # ('chr', 'iter-alt2-cython') : {'mean': 0.112, 'max': 0.467}, + # ('chr', 'recurse') : {'mean': 0.153, 'max': 0.648}, + # ('chr', 'iter-alt2') : {'mean': 0.155, 'max': 0.661}, + # ('chr', 'iter-alt1') : {'mean': 0.163, 'max': 0.707}, + # ('chr', 'iter-prehash2-cython'): {'mean': 0.197, 'max': 0.849}, + # ('chr', 'iter') : {'mean': 0.216, 'max': 0.933}, + # ('chr', 'iter-prehash2') : {'mean': 0.225, 'max': 0.974}, + # ('chr', 'iter-prehash') : {'mean': 0.253, 'max': 1.097}, + # } + + if BENCH_MODE == 'large': + data_basis = { + 'size': [30, 40], + 'max_depth': [4, 12], # 64000 + 'common': [8, 32], + 'prefix_depth1': [0, 4], + 'prefix_depth2': [2], + 'labels': [8] + } + run_basis['impl'] = balanced_sequence.available_impls_longest_common_balanced_sequence() + # runparam_to_time = { + # ('chr', 'iter-alt2-cython') : {'mean': 0.282, 'max': 0.923}, + # ('chr', 'recurse') : {'mean': 0.397, 'max': 1.297}, + # ('chr', 'iter-alt2') : {'mean': 0.409, 'max': 1.328}, + # ('chr', 'iter-alt1') : {'mean': 0.438, 'max': 1.428}, + # ('chr', 'iter-prehash2-cython'): {'mean': 0.511, 'max': 1.668}, + # ('chr', 'iter') : {'mean': 0.580, 'max': 1.915}, + # ('chr', 'iter-prehash2') : {'mean': 0.605, 'max': 1.962}, + # ('chr', 'iter-prehash') : {'mean': 0.679, 'max': 2.211}, + # } + + elif BENCH_MODE == 'too-big': + data_basis = { + 'size': [100], + 'max_depth': [8], + 'common': [80], + 'prefix_depth1': [4], + 'prefix_depth2': [2], + 
'labels': [8] + } + if BENCH_MODE == 'recursion-error': + data_basis = { + 'size': [0], + 'max_depth': [512], + 'common': [4], + 'prefix_depth1': [0], + 'prefix_depth2': [0], + 'labels': [256] + } + run_basis['impl'] = ub.oset(['recurse']) | ub.oset(balanced_sequence.available_impls_longest_common_balanced_sequence()) + # Results + # complexity = 69.48 + # stats1 = {'depth': 395, 'n_edges': 1203, 'n_leafs': 4, 'n_nodes': 1207, 'npaths': 4} + # stats2 = {'depth': 395, 'n_edges': 1203, 'n_leafs': 4, 'n_nodes': 1207, 'npaths': 4} + # runparam_to_time = { + # ('chr', 'recurse') : {'mean': NAN, 'max': NAN}, + # ('chr', 'iter-alt2-cython') : {'mean': 7.979, 'max': 7.979}, + # ('chr', 'iter-alt2') : {'mean': 11.307, 'max': 11.307}, + # ('chr', 'iter-alt1') : {'mean': 11.659, 'max': 11.659}, + # ('chr', 'iter-prehash2-cython'): {'mean': 15.230, 'max': 15.230}, + # ('chr', 'iter-prehash2') : {'mean': 17.058, 'max': 17.058}, + # ('chr', 'iter') : {'mean': 18.377, 'max': 18.377}, + # ('chr', 'iter-prehash') : {'mean': 19.508, 'max': 19.508}, + # } + + data_modes = [ + dict(zip(data_basis.keys(), vals)) + for vals in it.product(*data_basis.values())] + run_modes = [ + dict(zip(run_basis.keys(), vals)) + for vals in it.product(*run_basis.values())] + + print('len(data_modes) = {!r}'.format(len(data_modes))) + print('len(run_modes) = {!r}'.format(len(run_modes))) + print('total = {}'.format(len(data_modes) * len(run_modes))) + + seed = 0 + # if len(data_modes) < 10: + # for datakw in data_modes: + # _datakw = ub.dict_diff(datakw, {'complexity'}) + # paths1, paths2 = random_paths(seed=seed, **datakw) + # print('paths1 = {}'.format(ub.repr2(paths1, nl=1))) + # print('paths2 = {}'.format(ub.repr2(paths2, nl=1))) + # print('---') + for idx, datakw in enumerate(data_modes): + print('datakw = {}'.format(ub.repr2(datakw, nl=1))) + _datakw = ub.dict_diff(datakw, {'complexity'}) + paths1, paths2 = random_paths(seed=seed, **_datakw) + tree1 = path_embedding.paths_to_otree(paths1) + tree2 = path_embedding.paths_to_otree(paths2) + stats1 = { + 'npaths': len(paths1), + 'n_nodes': len(tree1.nodes), + 'n_edges': len(tree1.edges), + 'n_leafs': len([n for n in tree1.nodes if len(tree1.succ[n]) == 0]), + 'depth': max(len(p.split('/')) for p in paths1), + } + stats2 = { + 'npaths': len(paths2), + 'n_nodes': len(tree2.nodes), + 'n_edges': len(tree2.edges), + 'n_leafs': len([n for n in tree2.nodes if len(tree2.succ[n]) == 0]), + 'depth': max(len(p.split('/')) for p in paths2), + } + complexity = ( + stats1['n_nodes'] * min(stats1['n_leafs'], stats1['depth']) * + stats2['n_nodes'] * min(stats2['n_leafs'], stats2['depth'])) ** .25 + + datakw['complexity'] = complexity + print('datakw = {}'.format(ub.repr2(datakw, nl=0, precision=2))) + + if True: + # idx + 4 > len(data_modes): + print('stats1 = {}'.format(ub.repr2(stats1, nl=0))) + print('stats2 = {}'.format(ub.repr2(stats2, nl=0))) + # print('complexity = {:.2f}'.format(complexity)) + + total = len(data_modes) * len(run_modes) + print('len(data_modes) = {!r}'.format(len(data_modes))) + print('len(run_modes) = {!r}'.format(len(run_modes))) + print('total = {!r}'.format(total)) + seed = 0 + + prog = ub.ProgIter(total=total, verbose=3) + prog.begin() + results = [] + ti = timerit.Timerit(1, bestof=1, verbose=1, unit='s') + for datakw in data_modes: + _datakw = ub.dict_diff(datakw, {'complexity'}) + paths1, paths2 = random_paths(seed=seed, **_datakw) + print('---') + prog.step(4) + tree1 = path_embedding.paths_to_otree(paths1) + tree2 = path_embedding.paths_to_otree(paths2) + 
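+        # Summarize both trees and derive a rough difficulty estimate for this
+        # data mode. The product below mirrors the
+        # O(n1 * n2 * min(d1, l1) * min(d2, l2)) bound of the embedding
+        # algorithm; the fourth root keeps the printed "complexity" value on a
+        # readable scale.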
stats1 = { + 'npaths': len(paths1), + 'n_nodes': len(tree1.nodes), + 'n_edges': len(tree1.edges), + 'n_leafs': len([n for n in tree1.nodes if len(tree1.succ[n]) == 0]), + 'depth': max(len(p.split('/')) for p in paths1), + } + stats2 = { + 'npaths': len(paths2), + 'n_nodes': len(tree2.nodes), + 'n_edges': len(tree2.edges), + 'n_leafs': len([n for n in tree2.nodes if len(tree2.succ[n]) == 0]), + 'depth': max(len(p.split('/')) for p in paths2), + } + complexity = ( + stats1['n_nodes'] * min(stats1['n_leafs'], stats1['depth']) * + stats2['n_nodes'] * min(stats2['n_leafs'], stats2['depth'])) ** .25 + + datakw['complexity'] = complexity + print('datakw = {}'.format(ub.repr2(datakw, nl=0, precision=2))) + + if True: + # idx + 4 > len(data_modes): + print('stats1 = {}'.format(ub.repr2(stats1, nl=0))) + print('stats2 = {}'.format(ub.repr2(stats2, nl=0))) + for runkw in run_modes: + paramkw = {**datakw, **runkw} + run_key = ub.repr2( + paramkw, sep='', itemsep='', kvsep='', + explicit=1, nobr=1, nl=0, precision=1) + try: + for timer in ti.reset(run_key): + with timer: + maximum_common_path_embedding(paths1, paths2, **runkw) + except RecursionError as ex: + print('ex = {!r}'.format(ex)) + row = paramkw.copy() + row['time'] = float('nan') + else: + row = paramkw.copy() + row['time'] = ti.min() + results.append(row) + prog.end() + + print(ub.repr2(ub.sorted_vals(ti.measures['min']), nl=1, align=':', precision=6)) + + import pandas as pd + import kwarray + df = pd.DataFrame.from_dict(results) + + dataparam_to_time = {} + for mode, subdf in df.groupby(['complexity'] + list(data_basis.keys())): + stats = kwarray.stats_dict(subdf['time']) + stats.pop('min', None) + stats.pop('std', None) + stats.pop('shape', None) + dataparam_to_time[mode] = stats + dataparam_to_time = ub.sorted_vals(dataparam_to_time, key=lambda x: x['max']) + print('dataparam_to_time = {}'.format(ub.repr2(dataparam_to_time, nl=1, precision=3, align=':'))) + print(list(data_basis.keys())) + + runparam_to_time = {} + for mode, subdf in df.groupby(['mode', 'impl']): + stats = kwarray.stats_dict(subdf['time']) + stats.pop('min', None) + stats.pop('std', None) + stats.pop('shape', None) + runparam_to_time[mode] = stats + runparam_to_time = ub.sorted_vals(runparam_to_time, key=lambda x: x['max']) + print('runparam_to_time = {}'.format(ub.repr2(runparam_to_time, nl=1, precision=3, align=':'))) + + +def benchmark_balanced_sequence_single(): + from netharn.initializers._nx_ext import balanced_sequence + from netharn.initializers._nx_ext import demodata + import ubelt as ub + mode = 'number' + seq1, open_to_close = demodata.random_balanced_sequence(200, mode=mode) + seq2, open_to_close = demodata.random_balanced_sequence(400, mode=mode, open_to_close=open_to_close) + longest_common_balanced_sequence = balanced_sequence.longest_common_balanced_sequence + impls = balanced_sequence.available_impls_longest_common_balanced_sequence() + results = {} + for impl in impls: + with ub.Timer(impl): + best, val = longest_common_balanced_sequence( + seq1, seq2, open_to_close, node_affinity=None, impl=impl) + results[impl] = val + assert allsame(results.values()) + + +def allsame(iterable, eq=operator.eq): + """ + Determine if all items in a sequence are the same + + Args: + iterable (Iterable[A]): + items to determine if they are all the same + + eq (Callable[[A, A], bool], default=operator.eq): + function used to test for equality + + Returns: + bool: True if all items are equal, otherwise False + + Example: + >>> allsame([1, 1, 1, 1]) + True + >>> allsame([]) 
+ True + >>> allsame([0, 1]) + False + >>> iterable = iter([0, 1, 1, 1]) + >>> next(iterable) + >>> allsame(iterable) + True + >>> allsame(range(10)) + False + >>> allsame(range(10), lambda a, b: True) + True + """ + iter_ = iter(iterable) + try: + first = next(iter_) + except StopIteration: + return True + return all(eq(first, item) for item in iter_) diff --git a/netharn/initializers/_nx_ext/demodata.py b/netharn/initializers/_nx_ext/demodata.py new file mode 100644 index 0000000000000000000000000000000000000000..edcd827de0d9af6010adedb34087b6e5548dfe1d --- /dev/null +++ b/netharn/initializers/_nx_ext/demodata.py @@ -0,0 +1,239 @@ +""" +Helpers for creating random data for tests / benchmarks for the tree embedding +algorithms. +""" + + +def random_paths( + size=10, max_depth=10, common=0, prefix_depth1=0, prefix_depth2=0, + sep='/', labels=26, seed=None): + """ + Returns two randomly created paths (as in directory structures) for use in + testing and benchmarking :func:`maximum_common_path_embedding`. + + Parameters + ---------- + size : int + The number of independant random paths + + max_depth : int + Maximum depth for the independant random paths + + common : int + The number of shared common paths + + prefix_depth1: int + Depth of the random prefix attacheded to first common paths + + prefix_depth2: int + Depth of the random prefix attacheded to second common paths + + labels: int or collection + Number of or collection of tokens that can be used as node labels + + sep: str + path separator + + seed: + Random state or seed + + Examples + -------- + >>> paths1, paths2 = random_paths( + >>> size=5, max_depth=3, common=6, + >>> prefix_depth1=3, prefix_depth2=3, labels=2 ** 64, + >>> seed=0) + >>> from netharn.initializers._nx_ext.path_embedding import paths_to_otree + >>> from netharn.initializers._nx_ext.tree_embedding import tree_to_seq + >>> tree = paths_to_otree(paths1) + >>> seq, open_to_close, node_to_open = tree_to_seq(tree, mode='chr') + >>> seq, open_to_close, node_to_open = tree_to_seq(tree, mode='number') + >>> seq, open_to_close, node_to_open = tree_to_seq(tree, mode='tuple') + >>> # xdoctest: +REQUIRES(module:ubelt) + >>> import ubelt as ub + >>> print('paths1 = {}'.format(ub.repr2(paths1, nl=1))) + >>> print('paths2 = {}'.format(ub.repr2(paths2, nl=1))) + """ + from networkx.utils import create_py_random_state + rng = create_py_random_state(seed) + + if isinstance(labels, int): + alphabet = list(map(chr, range(ord('a'), ord('z')))) + + def random_label(): + digit = rng.randint(0, labels) + label = _convert_digit_base(digit, alphabet) + return label + else: + from functools import partial + random_label = partial(rng.choice, labels) + + def random_path(rng, max_depth): + depth = rng.randint(1, max_depth) + parts = [str(random_label()) for _ in range(depth)] + path = sep.join(parts) + return path + + # These paths might be shared (but usually not) + iid_paths1 = {random_path(rng, max_depth) for _ in range(size)} + iid_paths2 = {random_path(rng, max_depth) for _ in range(size)} + + # These paths will be shared + common_paths = {random_path(rng, max_depth) for _ in range(common)} + + if prefix_depth1 > 0: + prefix1 = random_path(rng, prefix_depth1) + common1 = {sep.join([prefix1, suff]) for suff in common_paths} + else: + common1 = common_paths + + if prefix_depth2 > 0: + prefix2 = random_path(rng, prefix_depth2) + common2 = {sep.join([prefix2, suff]) for suff in common_paths} + else: + common2 = common_paths + + paths1 = sorted(common1 | iid_paths1) + paths2 = sorted(common2 
| iid_paths2) + + return paths1, paths2 + + +def random_ordered_tree(n, seed=None): + """ + Creates a random ordered tree + + TODO + ---- + - [ ] Rename to random_ordered_directed_tree ? + - [ ] Merge in with other data generators? + + Parameters + ---------- + n : int + A positive integer representing the number of nodes in the tree. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + networkx.OrderedDiGraph + + Example + ------- + >>> assert len(random_ordered_tree(n=1, seed=0).nodes) == 1 + >>> assert len(random_ordered_tree(n=2, seed=0).nodes) == 2 + >>> assert len(random_ordered_tree(n=3, seed=0).nodes) == 3 + >>> from netharn.initializers._nx_ext.tree_embedding import forest_str + >>> print(forest_str(random_ordered_tree(n=5, seed=3))) + └── 1 + ├── 4 + │   ├── 3 + │   └── 2 + └── 0 + """ + import networkx as nx + from networkx.utils import create_py_random_state + rng = create_py_random_state(seed) + # Create a random undirected tree + utree = nx.random_tree(n, seed=rng) + # Use a random root node and dfs to define edge directions + nodes = list(utree.nodes) + source = rng.choice(nodes) + edges = nx.dfs_edges(utree, source=source) + # Populate the ordered graph + otree = nx.OrderedDiGraph() + otree.add_nodes_from(utree.nodes) + otree.add_edges_from(edges) + return otree + + +def random_balanced_sequence(n, seed=None, mode='chr', open_to_close=None): + r""" + Creates a random balanced sequence for testing / benchmarks + + Parameters + ---------- + n : int + A positive integer representing the number of nodes in the tree. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + open_to_close : dict | None + if specified, updates existing open_to_close with tokens from this + sequence. + + mode: str + the type of sequence returned (see :func:`tree_to_seq` for details) + + Returns + ------- + : tuple + The first item is the sequence itself + the second item is the open_to_close mappings. 
+ + Example + ------- + >>> # Demo the various sequence encodings that we might use + >>> seq, open_to_close = random_balanced_sequence(2, seed=1, mode='tuple') + >>> print('seq = {!r}'.format(seq)) + >>> seq, open_to_close = random_balanced_sequence(4, seed=1, mode='chr') + >>> print('seq = {!r}'.format(seq)) + >>> seq, open_to_close = random_balanced_sequence(4, seed=1, mode='number') + >>> print('seq = {!r}'.format(seq)) + >>> seq, open_to_close = random_balanced_sequence(4, seed=1, mode='str') + >>> print('seq = {!r}'.format(seq)) + >>> seq, open_to_close = random_balanced_sequence(10, seed=1, mode='paren') + >>> print('seq = {!r}'.format(seq)) + seq = (('open', 0), ('open', 1), ('close', 1), ('close', 0)) + seq = '\x00\x02\x04\x06\x07\x05\x03\x01' + seq = (1, 2, 3, 4, -4, -3, -2, -1) + seq = ('2(', '1(', '0(', '3(', ')3', ')0', ')1', ')2') + seq = '([[[]{{}}](){{[]}}])' + """ + from networkx.utils import create_py_random_state + from netharn.initializers._nx_ext.tree_embedding import tree_to_seq + # Create a random otree and then convert it to a balanced sequence + rng = create_py_random_state(seed) + tree = random_ordered_tree(n, seed=rng) + if mode == 'paren': + pool = '[{(' + for node in tree.nodes: + tree.nodes[node]['label'] = rng.choice(pool) + seq, open_to_close, _ = tree_to_seq( + tree, mode=mode, open_to_close=open_to_close, strhack=1) + else: + seq, open_to_close, _ = tree_to_seq( + tree, mode=mode, open_to_close=open_to_close) + return seq, open_to_close + + +def _convert_digit_base(digit, alphabet): + """ + Parameters + ---------- + digit : int + number in base 10 to convert + + alphabet : list + symbols of the conversion base + """ + baselen = len(alphabet) + x = digit + if x == 0: + return alphabet[0] + sign = 1 if x > 0 else -1 + x *= sign + digits = [] + while x: + digits.append(alphabet[x % baselen]) + x //= baselen + if sign < 0: + digits.append('-') + digits.reverse() + newbase_str = ''.join(digits) + return newbase_str diff --git a/netharn/initializers/_nx_ext/path_embedding.py b/netharn/initializers/_nx_ext/path_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7a71a532b18a59053c6efbe664875d05b38489 --- /dev/null +++ b/netharn/initializers/_nx_ext/path_embedding.py @@ -0,0 +1,143 @@ +import networkx as nx +from .tree_embedding import maximum_common_ordered_tree_embedding + + +def maximum_common_path_embedding(paths1, paths2, sep='/', impl='iter-alt2', mode='chr'): + """ + Finds the maximum path embedding common between two sets of paths + + Parameters + ---------- + paths1, paths2: List[str] + a list of paths + + sep: str + path separator character + + impl: str + backend runtime to use + + mode: str + backend representation to use + + Returns + ------- + :tuple + corresponding lists subpaths1 and subpaths2 which are subsets of + paths1 and path2 respectively + + Examples + -------- + >>> paths1 = [ + >>> '/usr/bin/python', + >>> '/usr/bin/python3.6.1', + >>> '/usr/lib/python3.6/dist-packages/networkx', + >>> '/usr/lib/python3.6/dist-packages/numpy', + >>> '/usr/include/python3.6/Python.h', + >>> ] + >>> paths2 = [ + >>> '/usr/local/bin/python', + >>> '/usr/bin/python3.6.2', + >>> '/usr/local/lib/python3.6/dist-packages/networkx', + >>> '/usr/local/lib/python3.6/dist-packages/scipy', + >>> '/usr/local/include/python3.6/Python.h', + >>> ] + >>> subpaths1, subpaths2 = maximum_common_path_embedding(paths1, paths2) + >>> import pprint + >>> print('subpaths1 = {}'.format(pprint.pformat(subpaths1))) + >>> print('subpaths2 = 
{}'.format(pprint.pformat(subpaths2))) + subpaths1 = ['/usr/bin/python', + '/usr/include/python3.6/Python.h', + '/usr/lib/python3.6/dist-packages/networkx'] + subpaths2 = ['/usr/local/bin/python', + '/usr/local/include/python3.6/Python.h', + '/usr/local/lib/python3.6/dist-packages/networkx'] + """ + # the longest common balanced sequence problem + def _affinity(node1, node2): + score = 0 + for t1, t2 in zip(node1[::-1], node2[::-1]): + if t1 == t2: + score += 1 + else: + break + return score + node_affinity = _affinity + + tree1 = paths_to_otree(paths1, sep=sep) + tree2 = paths_to_otree(paths2, sep=sep) + + subtree1, subtree2 = maximum_common_ordered_tree_embedding( + tree1, tree2, node_affinity=node_affinity, impl=impl, mode=mode) + + subpaths1 = [sep.join(node) for node in subtree1.nodes if subtree1.out_degree[node] == 0] + subpaths2 = [sep.join(node) for node in subtree2.nodes if subtree2.out_degree[node] == 0] + return subpaths1, subpaths2 + + +def paths_to_otree(paths, sep='/'): + """ + Generates an ordered tree from a list of path strings + + Parameters + ---------- + paths: List[str] + a list of paths + + sep : str + path separation character. defaults to "/" + + Returns + ------- + nx.OrderedDiGraph + + Example + ------- + >>> from netharn.initializers._nx_ext.tree_embedding import forest_str + >>> paths = [ + >>> '/etc/ld.so.conf', + >>> '/usr/bin/python3.6', + >>> '/usr/include/python3.6/Python.h', + >>> '/usr/lib/python3.6/config-3.6m-x86_64-linux-gnu/libpython3.6.so', + >>> '/usr/local/bin/gnumake.h', + >>> '/usr/local/etc', + >>> '/usr/local/lib/python3.6/dist-packages/', + >>> ] + >>> otree = paths_to_otree(paths) + >>> print(forest_str(otree)) + └── / + ├── usr + │   ├── local + │   │   ├── lib + │   │   │   └── python3.6 + │   │   │   └── dist-packages + │   │   │   └── + │   │   ├── etc + │   │   └── bin + │   │   └── gnumake.h + │   ├── lib + │   │   └── python3.6 + │   │   └── config-3.6m-x86_64-linux-gnu + │   │   └── libpython3.6.so + │   ├── include + │   │   └── python3.6 + │   │   └── Python.h + │   └── bin + │   └── python3.6 + └── etc + └── ld.so.conf + """ + otree = nx.OrderedDiGraph() + for path in sorted(paths): + parts = tuple(path.split(sep)) + node_path = [] + for i in range(1, len(parts) + 1): + node = parts[0:i] + otree.add_node(node) + otree.nodes[node]['label'] = node[-1] + node_path.append(node) + for u, v in zip(node_path[:-1], node_path[1:]): + otree.add_edge(u, v) + if ('',) in otree.nodes: + otree.nodes[('',)]['label'] = sep + return otree diff --git a/netharn/initializers/_nx_ext/tests/test_balanced_sequence.py b/netharn/initializers/_nx_ext/tests/test_balanced_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..ed88346ee23d2db4aec9628db25884244e57b64d --- /dev/null +++ b/netharn/initializers/_nx_ext/tests/test_balanced_sequence.py @@ -0,0 +1,32 @@ + + +def test_all_implementations_are_same(): + """ + Tests several random sequences + """ + from netharn.initializers._nx_ext import balanced_sequence + from netharn.initializers._nx_ext import demodata + from networkx.utils import create_py_random_state + + seed = 93024896892223032652928827097264 + rng = create_py_random_state(seed) + + maxsize = 20 + num_trials = 5 + + for _ in range(num_trials): + n1 = rng.randint(1, maxsize) + n2 = rng.randint(1, maxsize) + + seq1, open_to_close = demodata.random_balanced_sequence(n1, seed=rng) + seq2, open_to_close = demodata.random_balanced_sequence(n2, open_to_close=open_to_close, seed=rng) + longest_common_balanced_sequence = 
balanced_sequence.longest_common_balanced_sequence + + # Note: the returned sequences may be different (maximum embeddings may not + # be unique), but the values should all be the same. + results = {} + impls = balanced_sequence.available_impls_longest_common_balanced_sequence() + for impl in impls: + best, val = longest_common_balanced_sequence( + seq1, seq2, open_to_close, node_affinity=None, impl=impl) + results[impl] = val diff --git a/netharn/initializers/_nx_ext/tests/test_path_embedding.py b/netharn/initializers/_nx_ext/tests/test_path_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..82f2cd9b4f28c83f8bd1c19d1ae6897d06cc20c5 --- /dev/null +++ b/netharn/initializers/_nx_ext/tests/test_path_embedding.py @@ -0,0 +1,260 @@ +from netharn.initializers._nx_ext.path_embedding import maximum_common_path_embedding +from netharn.initializers._nx_ext.demodata import random_paths + + +def test_not_compatable(): + paths1 = [ + 'foo/bar' + ] + paths2 = [ + 'baz/biz' + ] + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + assert len(embedding1) == 0 + assert len(embedding2) == 0 + + +def test_compatable(): + paths1 = [ + 'root/suffix1' + ] + paths2 = [ + 'root/suffix2' + ] + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + assert embedding1 == ['root'] + assert embedding2 == ['root'] + + paths1 = [ + 'root/suffix1' + ] + paths2 = [ + 'root' + ] + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + assert embedding1 == ['root'] + assert embedding2 == ['root'] + + +def test_prefixed(): + paths1 = [ + 'prefix1/root/suffix1' + ] + paths2 = [ + 'root/suffix2' + ] + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + assert embedding1 == ['prefix1/root'] + assert embedding2 == ['root'] + + paths1 = [ + 'prefix1/root/suffix1' + ] + paths2 = [ + 'prefix1/root/suffix2' + ] + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + assert embedding1 == ['prefix1/root'] + assert embedding2 == ['prefix1/root'] + + +def test_simple1(): + paths1 = [ + 'root/file1', + 'root/file2', + 'root/file3', + ] + paths2 = [ + 'prefix1/root/file1', + 'prefix1/root/file2', + 'root/file3', + ] + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + assert embedding1 == paths1 + assert embedding2 == paths2 + + paths1 = [ + 'root/file1', + 'root/file2', + 'root/file3', + ] + paths2 = [ + 'prefix1/root/file1', + 'prefix1/root/file2', + 'prefix2/root/file3', + 'prefix2/root/file4', + ] + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + assert embedding1 == paths1 + + +def test_random1(): + paths1, paths2 = random_paths(10, seed=321) + embedding1, embedding2 = maximum_common_path_embedding(paths1, paths2) + + +def _demodata_resnet_module_state(arch): + """ + Construct paths corresponding to resnet convnet state keys to + simulate a real world use case for path-embeddings. 
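+
+    Example
+    -------
+    A quick sanity check of the generated key layout (the expected values
+    below are derived from the construction in this helper, not measured
+    against torchvision):
+
+    >>> paths = _demodata_resnet_module_state('resnet18')
+    >>> paths[0], paths[-1]
+    ('conv1.weight', 'fc.bias')
+    >>> len(paths)
+    122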
+ + Ignore + ------ + # Check to make sure the demodata agrees with real data + import torchvision + paths_true = list(torchvision.models.resnet50().state_dict().keys()) + paths_demo = _demodata_resnet_module_state('resnet50') + print(ub.hzcat([ub.repr2(paths_true, nl=2), ub.repr2(paths_demo)])) + assert paths_demo == paths_true + + paths_true = list(torchvision.models.resnet18().state_dict().keys()) + paths_demo = _demodata_resnet_module_state('resnet18') + print(ub.hzcat([ub.repr2(paths_true, nl=2), ub.repr2(paths_demo)])) + assert paths_demo == paths_true + + paths_true = list(torchvision.models.resnet152().state_dict().keys()) + paths_demo = _demodata_resnet_module_state('resnet152') + print(ub.hzcat([ub.repr2(paths_true, nl=2), ub.repr2(paths_demo)])) + assert paths_demo == paths_true + """ + if arch == 'resnet18': + block_type = 'basic' + layer_blocks = [2, 2, 2, 2] + elif arch == 'resnet50': + block_type = 'bottleneck' + layer_blocks = [3, 4, 6, 3] + elif arch == 'resnet152': + block_type = 'bottleneck' + layer_blocks = [3, 8, 36, 3] + else: + raise KeyError(arch) + paths = [] + paths += [ + 'conv1.weight', + 'bn1.weight', + 'bn1.bias', + 'bn1.running_mean', + 'bn1.running_var', + 'bn1.num_batches_tracked', + ] + if block_type == 'bottleneck': + num_convs = 3 + elif block_type == 'basic': + num_convs = 2 + else: + raise KeyError(block_type) + + for layer_idx, nblocks in enumerate(layer_blocks, start=1): + for block_idx in range(0, nblocks): + prefix = 'layer{}.{}.'.format(layer_idx, block_idx) + + for conv_idx in range(1, num_convs + 1): + paths += [ + prefix + 'conv{}.weight'.format(conv_idx), + prefix + 'bn{}.weight'.format(conv_idx), + prefix + 'bn{}.bias'.format(conv_idx), + prefix + 'bn{}.running_mean'.format(conv_idx), + prefix + 'bn{}.running_var'.format(conv_idx), + prefix + 'bn{}.num_batches_tracked'.format(conv_idx), + ] + if block_idx == 0 and layer_idx > 0: + if block_type != 'basic' or layer_idx > 1: + paths += [ + prefix + 'downsample.0.weight', + prefix + 'downsample.1.weight', + prefix + 'downsample.1.bias', + prefix + 'downsample.1.running_mean', + prefix + 'downsample.1.running_var', + prefix + 'downsample.1.num_batches_tracked', + ] + paths += [ + 'fc.weight', + 'fc.bias', + ] + return paths + + +def test_realworld_case1(): + """ + import torchvision + paths1 = list(torchvision.models.resnet50().state_dict().keys()) + + print(ub.hzcat(['paths1 = {}'.format(ub.repr2(paths1, nl=2)), ub.repr2(paths)])) + len(paths1) + """ + # times: resnet18: 0.16 seconds + # times: resnet50: 0.93 seconds + # times: resnet152: 9.83 seconds + paths1 = _demodata_resnet_module_state('resnet50') + paths2 = ['module.' + p for p in paths1] + # import ubelt as ub + # with ub.Timer('test-real-world-case'): + embedding1, embedding2 = maximum_common_path_embedding( + paths1, paths2, sep='.') + assert [p[len('module.'):] for p in embedding2] == embedding1 + + +def test_realworld_case2(): + """ + import torchvision + paths1 = list(torchvision.models.resnet152().state_dict().keys()) + print('paths1 = {}'.format(ub.repr2(paths1, nl=2))) + """ + backbone = _demodata_resnet_module_state('resnet18') + + # Detector strips of prefix and suffix of the backbone net + subpaths = ['detector.' 
+ p for p in backbone[6:-2]] + paths1 = [ + 'detector.conv1.weight', + 'detector.bn1.weight', + 'detector.bn1.bias', + ] + subpaths + [ + 'detector.head1.conv1.weight', + 'detector.head1.conv2.weight', + 'detector.head1.conv3.weight', + 'detector.head1.fc.weight', + 'detector.head1.fc.bias', + 'detector.head2.conv1.weight', + 'detector.head2.conv2.weight', + 'detector.head2.conv3.weight', + 'detector.head2.fc.weight', + 'detector.head2.fc.bias', + ] + + paths2 = ['module.' + p for p in backbone] + + # import ubelt as ub + # with ub.Timer('test-real-world-case'): + embedding1, embedding2 = maximum_common_path_embedding( + paths1, paths2, sep='.') + + mapping = dict(zip(embedding1, embedding2)) + + # Note in the embedding case there may be superfluous assignments + # but they can either be discarded in post-processing or they wont + # be in the solution if we use isomorphisms instead of embeddings + assert len(subpaths) < len(mapping), ( + 'all subpaths should be in the mapping') + + non_common1 = set(paths1) - set(embedding1) + non_common2 = set(paths2) - set(embedding2) + + assert non_common2 == { + 'module.bn1.num_batches_tracked', + 'module.bn1.running_mean', + 'module.bn1.running_var', + } + + assert non_common1 == { + 'detector.conv1.weight', + 'detector.head1.conv1.weight', + 'detector.head1.conv2.weight', + 'detector.head1.conv3.weight', + 'detector.head1.fc.bias', + 'detector.head1.fc.weight', + 'detector.head2.conv2.weight', + 'detector.head2.conv3.weight', + } + # print('non_common1 = {}'.format(ub.repr2(non_common1, nl=1))) + # print('non_common2 = {}'.format(ub.repr2(non_common2, nl=1))) + # assert [p[len('module.'):] for p in embedding2] == embedding1 diff --git a/netharn/initializers/_nx_ext/tests/test_tree_embedding.py b/netharn/initializers/_nx_ext/tests/test_tree_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a2048601a1068066c02c9f3a8b298964e0b917 --- /dev/null +++ b/netharn/initializers/_nx_ext/tests/test_tree_embedding.py @@ -0,0 +1,109 @@ +from netharn.initializers._nx_ext.tree_embedding import ( + maximum_common_ordered_tree_embedding, forest_str) + +from netharn.initializers._nx_ext.demodata import ( + random_ordered_tree +) +import networkx as nx +import pytest +from networkx.utils import create_py_random_state + + +def test_null_common_embedding(): + """ + The empty graph is not a tree and should raise an error + """ + empty = nx.OrderedDiGraph() + non_empty = random_ordered_tree(n=1) + + with pytest.raises(nx.NetworkXPointlessConcept): + maximum_common_ordered_tree_embedding(empty, empty) + + with pytest.raises(nx.NetworkXPointlessConcept): + maximum_common_ordered_tree_embedding(empty, non_empty) + + with pytest.raises(nx.NetworkXPointlessConcept): + maximum_common_ordered_tree_embedding(non_empty, empty) + + +def test_self_common_embedding(): + """ + The common embedding of a tree with itself should always be itself + """ + rng = create_py_random_state(85652972257) + for n in range(1, 10): + tree = random_ordered_tree(n=n, seed=rng) + embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree, tree) + assert tree.edges == embedding1.edges + + +def test_common_tree_embedding_small(): + tree1 = nx.OrderedDiGraph([(0, 1)]) + tree2 = nx.OrderedDiGraph([(0, 1), (1, 2)]) + print(forest_str(tree1)) + print(forest_str(tree2)) + + embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree1, tree2) + print(forest_str(embedding1)) + print(forest_str(embedding2)) + + +def test_common_tree_embedding_small2(): + tree1 = 
nx.OrderedDiGraph([(0, 1), (2, 3), (4, 5), (5, 6)]) + tree2 = nx.OrderedDiGraph([(0, 1), (1, 2), (0, 3)]) + print(forest_str(tree1)) + print(forest_str(tree2)) + + embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity=None) + print(forest_str(embedding1)) + print(forest_str(embedding2)) + + +def test_all_implementations_are_same(): + """ + Tests several random sequences + """ + from netharn.initializers._nx_ext import balanced_sequence + from netharn.initializers._nx_ext import demodata + from networkx.utils import create_py_random_state + + seed = 24658885408229410362279507020239 + rng = create_py_random_state(seed) + + maxsize = 20 + num_trials = 5 + + for _ in range(num_trials): + n1 = rng.randint(1, maxsize) + n2 = rng.randint(1, maxsize) + + tree1 = demodata.random_ordered_tree(n1, seed=rng) + tree2 = demodata.random_ordered_tree(n2, seed=rng) + + # Note: the returned sequences may be different (maximum embeddings may not + # be unique), but the values should all be the same. + results = {} + impls = balanced_sequence.available_impls_longest_common_balanced_sequence() + for impl in impls: + # FIXME: do we need to rework the return value here? + subtree1, subtree2 = maximum_common_ordered_tree_embedding( + tree1, tree2, node_affinity=None, impl=impl) + _check_common_embedding_invariants(tree1, tree2, subtree1, subtree2) + results[impl] = len(subtree1.nodes) + + x = max(results.values()) + assert all(v == x for v in results.values()) + + +def _check_embedding_invariants(tree, subtree): + assert set(subtree.nodes).issubset(set(tree.nodes)), 'must have a node subset' + assert len(subtree.edges) <= len(tree.edges) + + +def _check_common_embedding_invariants(tree1, tree2, subtree1, subtree2): + """ + Validates that this solution satisfies properties of an embedding + """ + _check_embedding_invariants(tree1, subtree1) + _check_embedding_invariants(tree2, subtree2) + assert len(subtree1.nodes) == len(subtree2.nodes) diff --git a/netharn/initializers/_nx_ext/tree_embedding.py b/netharn/initializers/_nx_ext/tree_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..9978eb711ef316073cb17ac69f51cb055e897655 --- /dev/null +++ b/netharn/initializers/_nx_ext/tree_embedding.py @@ -0,0 +1,558 @@ +""" +Algorithm for computing tree embeddings +""" +import networkx as nx +from collections import OrderedDict, defaultdict +from .balanced_sequence import longest_common_balanced_sequence, UnbalancedException + + +def maximum_common_ordered_tree_embedding( + tree1, tree2, node_affinity='auto', impl='auto', mode='chr'): + """ + Finds the maximum common subtree-embedding between two ordered trees. + + A tree S is an embedded subtree of T if it can be obtained from T by a + series of edge contractions. + + Note this produces a subtree embedding, which is not necessarilly a + subgraph isomorphism (although a subgraph isomorphism is also an + embedding.) + + The maximum common embedded subtree problem can be solved in in + `O(n1 * n2 * min(d1, l1) * min(d2, l2))` time on ordered trees with n1 and + n2 nodes, of depth d1 and d2 and with l1 and l2 leaves, respectively. + + Implements algorithm described in [1]_, which introduces the problem as + follows: + + "An important generalization of tree and subtree isomorphism, known as + minor containment, is the problem of determining whether a tree is + isomorphic to an embedded subtree of another tree, where an embedded + subtree of a tree is obtained by contracting some of the edges in the tree. 
+ A further generalization of minor containment on trees, known as maximum + common embedded subtree, is the problem of finding or determining the size + of a largest common embedded subtree of two trees. The latter also + generalizes the maximum common subtree isomorphism problem, in which a + common subtree of largest size is contained as a subtree, not only + embedded, in the two trees." + + Parameters + ---------- + tree1, tree2 : nx.OrderedDiGraph + Trees to find the maximum embedding between + + node_affinity : None | str | callable + Function for to determine if two nodes can be matched. The return is + interpreted as a weight that is used to break ties. If None then any + node can match any other node and only the topology is important. + The default is "eq", which is the same as ``operator.eq``. + + impl : str + Determines the backend implementation + + mode : str + Determines the backend representation + + References + ---------- + .. [1] Lozano, Antoni, and Gabriel Valiente. + "On the maximum common embedded subtree problem for ordered trees." + String Algorithmics (2004): 155-170. + https://pdfs.semanticscholar.org/0b6e/061af02353f7d9b887f9a378be70be64d165.pdf + + Returns + ------- + Tuple[nx.OrderedDiGraph, nx.OrderedDiGraph] : + The maximum value common embedding for each tree with respect to the + chosen ``node_affinity`` function. The topology of both graphs will + always be the same, the only difference is that the node labels in the + first and second embeddings will correspond to ``tree1`` and `tree2`` + respectively. When ``node_affinity='eq'`` then embeddings should be + identical. + + Example + ------- + >>> from netharn.initializers._nx_ext.tree_embedding import * # NOQA + >>> from netharn.initializers._nx_ext.demodata import random_ordered_tree # NOQA + >>> tree1 = random_ordered_tree(7, seed=3257073545741117277206611) + >>> tree2 = random_ordered_tree(7, seed=123568587133124688238689717) + >>> print('tree1') + >>> forest_str(tree1, write=print) + >>> print('tree2') + >>> forest_str(tree2, write=print) + >>> embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree1, tree2 ) + >>> print('embedding1') + >>> forest_str(embedding1, write=print) + >>> print('embedding2') + >>> forest_str(embedding2, write=print) + tree1 + └── 1 + ├── 6 + │   ├── 4 + │   └── 3 + └── 0 + └── 5 + └── 2 + tree2 + └── 4 + └── 1 + ├── 2 + │   ├── 6 + │   └── 0 + └── 3 + └── 5 + embedding1 + └── 1 + ├── 6 + └── 5 + embedding2 + └── 1 + ├── 6 + └── 5 + """ + if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree1)): + raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') + if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree2)): + raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') + + # Convert the trees to balanced sequences + sequence1, open_to_close, node_to_open = tree_to_seq( + tree1, open_to_close=None, node_to_open=None, mode=mode) + sequence2, open_to_close, node_to_open = tree_to_seq( + tree2, open_to_close, node_to_open, mode=mode) + seq1 = sequence1 + seq2 = sequence2 + + # FIXME: I think this may cause bugs in two cases, which may or may not be + # possible, but I need to look into it and provide a fix or justification + # as to why these cases wont be hit: + # (1) when the two trees share nodes that have different open tokens + # (2) when the mapping between nodes to opening tokens is not unique. 
+ # I'm not sure if this second case can happen when we are converting + # from a tree to a sequence, there are certainly sequences where the + # same opening token might share multiple tree nodes. + open_to_node = invert_dict(node_to_open) + + # Solve the longest common balanced sequence problem + best, value = longest_common_balanced_sequence( + seq1, seq2, open_to_close, open_to_node=open_to_node, + node_affinity=node_affinity, impl=impl) + subseq1, subseq2 = best + + # Convert the subsequence back into a tree + embedding1 = seq_to_tree(subseq1, open_to_close, open_to_node) + embedding2 = seq_to_tree(subseq2, open_to_close, open_to_node) + return embedding1, embedding2 + + +def tree_to_seq(tree, open_to_close=None, node_to_open=None, mode='tuple', strhack=None): + r""" + Converts an ordered tree to a balanced sequence for use in algorithm + reductions. + + Parameters + ---------- + open_to_close : Dict | None + Dictionary of opening to closing tokens to be updated for problems + where multiple trees are converted to sequences. + + open_to_node : Dict | None + Dictionary of opening tokens to nodes to be updated for problems where + multiple trees are converted to sequences. + + mode : str + Currently hacky and needs refactor. + Can be 'tuple', 'number', or 'chr'. + Hackier variants are 'str' and 'paren'. + + strhack : bool + Currently hacky and needs refactor. If False, always return a tuple of + items, if True, tries to return a string of items. If None, tries to + choose a value depending on mode. + + Example + ------- + >>> from netharn.initializers._nx_ext.tree_embedding import * # NOQA + >>> tree = nx.path_graph(3, nx.OrderedDiGraph) + >>> print(forest_str(tree)) + >>> sequence, open_to_close, node_to_open = tree_to_seq(tree, mode='number') + >>> print('sequence = {!r}'.format(sequence)) + └── 0 + └── 1 + └── 2 + sequence = (1, 2, 3, -3, -2, -1) + + >>> tree = nx.balanced_tree(2, 2, nx.OrderedDiGraph) + >>> print(forest_str(tree)) + >>> sequence, open_to_close, node_to_open = tree_to_seq(tree, mode='number') + >>> print('sequence = {!r}'.format(sequence)) + └── 0 + ├── 2 + │   ├── 6 + │   └── 5 + └── 1 + ├── 4 + └── 3 + sequence = (1, 2, 3, -3, 4, -4, -2, 5, 6, -6, 7, -7, -5, -1) + + >>> from netharn.initializers._nx_ext.demodata import random_ordered_tree # NOQA + >>> tree = random_ordered_tree(2, seed=1) + >>> sequence, open_to_close, node_to_open = tree_to_seq(tree, mode='tuple') + >>> print('sequence = {!r}'.format(sequence)) + >>> sequence, open_to_close, node_to_open = tree_to_seq(tree, mode='chr') + >>> print('sequence = {!r}'.format(sequence)) + >>> sequence, open_to_close, node_to_open = tree_to_seq(tree, mode='number') + >>> print('sequence = {!r}'.format(sequence)) + sequence = (('open', 0), ('open', 1), ('close', 1), ('close', 0)) + sequence = '\x00\x02\x03\x01' + sequence = (1, 2, -2, -1) + """ + # mapping between opening and closing tokens + sources = [n for n in tree.nodes if tree.in_degree[n] == 0] + sequence = [] + + if strhack is None: + if mode == 'chr': + strhack = True + + if open_to_close is None: + open_to_close = {} + if node_to_open is None: + node_to_open = {} + + if strhack: + if mode == 'paren': + all_labels = {n['label'] for n in list(tree.nodes.values())} + assert all(x == 1 for x in map(len, all_labels)) + + for source in sources: + for u, v, etype in nx.dfs_labeled_edges(tree, source=source): + if etype == 'forward': + # u has been visited by v has not + if v not in node_to_open: + if mode == 'tuple': + open_tok = ('open', v) + close_tok = ('close', 
v) + elif mode == 'number': + open_tok = len(node_to_open) + 1 + close_tok = -open_tok + elif mode == 'str': + open_tok = '{}('.format(v) + close_tok = '){}'.format(v) + elif mode == 'chr': + if not strhack: + # note ussing the accent mark wont work in string + # mode even though the close tok renders as a + # single character. + open_tok = str(v) + close_tok = str(v) + u'\u0301' + else: + # utf8 can only encode this many chars + assert len(node_to_open) < (1112064 // 2) + open_tok = chr(len(node_to_open) * 2) + close_tok = chr(len(node_to_open) * 2 + 1) + elif mode == 'paren': + open_tok = tree.nodes[v]['label'] + assert strhack + if open_tok == '{': + close_tok = '}' + elif open_tok == '[': + close_tok = ']' + elif open_tok == '(': + close_tok = ')' + else: + raise KeyError(open_tok) + else: + raise KeyError(mode) + node_to_open[v] = open_tok + open_to_close[open_tok] = close_tok + open_tok = node_to_open[v] + sequence.append(open_tok) + elif etype == 'reverse': + # Both u and v are visited and the edge is in the tree + close_tok = open_to_close[node_to_open[v]] + sequence.append(close_tok) + else: + raise KeyError(etype) + sequence = tuple(sequence) + if strhack: + sequence = ''.join(sequence) + return sequence, open_to_close, node_to_open + + +def seq_to_tree(subseq, open_to_close, open_to_node): + """ + Converts a balanced sequence to an ordered tree + + Parameters + ---------- + subseq : Tuple | str + a balanced sequence of hashable items as a string or tuple + + open_to_close : Dict + a dictionary that maps opening tokens to closing tokens in the balanced + sequence problem. + + open_to_node : Dict + a dictionary that maps a sequence token to a node corresponding to an + original problem (e.g. a tree node). Must be unique. If unspecified new + nodes will be generated and the opening sequence token will be used as + a node label. + + Example + -------- + >>> from netharn.initializers._nx_ext.demodata import random_ordered_tree + >>> open_to_close = {'{': '}', '(': ')', '[': ']'} + >>> open_to_node = None + >>> subseq = '({[[]]})[[][]]{{}}' + >>> subtree = seq_to_tree(subseq, open_to_close, open_to_node) + >>> print(forest_str(subtree)) + ├── { + │   └── { + ├── [ + │   ├── [ + │   └── [ + └── ( + └── { + └── [ + └── [ + """ + nextnode = 0 # only used if open_to_node is not specified + subtree = nx.OrderedDiGraph() + stack = [] + for token in subseq: + if token in open_to_close: + if open_to_node is None: + node = nextnode + nextnode += 1 + else: + node = open_to_node[token] + if stack: + parent_tok, parent_node = stack[-1] + subtree.add_edge(parent_node, node) + else: + subtree.add_node(node) + if open_to_node is None: + subtree.nodes[node]['label'] = token + stack.append((token, node)) + else: + if not stack: + raise UnbalancedException + prev_open, prev_node = stack.pop() + want_close = open_to_close[prev_open] + if token != want_close: + raise UnbalancedException + return subtree + + +def invert_dict(dict_, unique_vals=True): + """ + Swaps the keys and values in a dictionary. + + Parameters + ---------- + dict_ (Dict[A, B]): dictionary to invert + + unique_vals (bool, default=True): if False, the values of the new + dictionary are sets of the original keys. + + Returns + ------- + Dict[B, A] | Dict[B, Set[A]]: + the inverted dictionary + + Notes + ----- + The must values be hashable. + + If the original dictionary contains duplicate values, then only one of + the corresponding keys will be returned and the others will be + discarded. 
This can be prevented by setting ``unique_vals=False``, + causing the inverted keys to be returned in a set. + + Example + ------- + >>> from netharn.initializers._nx_ext.tree_embedding import * # NOQA + >>> dict_ = {'a': 1, 'b': 2} + >>> inverted = invert_dict(dict_) + >>> assert inverted == {1: 'a', 2: 'b'} + + Example + ------- + >>> from netharn.initializers._nx_ext.tree_embedding import * # NOQA + >>> dict_ = OrderedDict([(2, 'a'), (1, 'b'), (0, 'c'), (None, 'd')]) + >>> inverted = invert_dict(dict_) + >>> assert list(inverted.keys())[0] == 'a' + + Example + ------- + >>> from netharn.initializers._nx_ext.tree_embedding import * # NOQA + >>> dict_ = {'a': 1, 'b': 0, 'c': 0, 'd': 0, 'f': 2} + >>> inverted = invert_dict(dict_, unique_vals=False) + >>> assert inverted == {0: {'b', 'c', 'd'}, 1: {'a'}, 2: {'f'}} + """ + if unique_vals: + if isinstance(dict_, OrderedDict): + inverted = OrderedDict((val, key) for key, val in dict_.items()) + else: + inverted = {val: key for key, val in dict_.items()} + else: + # Handle non-unique keys using groups + inverted = defaultdict(set) + for key, value in dict_.items(): + inverted[value].add(key) + inverted = dict(inverted) + return inverted + + +def forest_str(graph, use_labels=True, sources=None, write=None): + """ + Creates a nice utf8 representation of a directed forest + + Parameters + ---------- + graph : nx.DiGraph | nx.Graph + Graph to represent (must be a tree, forest, or the empty graph) + + use_labels : bool + If True will use the "label" attribute of a node to display if it + exists otherwise it will use the node value itself. Defaults to True. + + sources : List + Mainly relevant for undirected forests, specifies which nodes to list + first. If unspecified the root nodes of each tree will be used for + directed forests; for undirected forests this defaults to the nodes + with the smallest degree. + + write : callable + Function to use to write to, if None new lines are appended to + a list and returned. If set to the `print` function, lines will + be written to stdout as they are generated. If specified, + this function will return None. Defaults to None. 
+ + Returns + ------- + str | None : + utf8 representation of the tree / forest + + Example + ------- + >>> import networkx as nx + >>> graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) + >>> print(forest_str(graph)) + ╙── 0 + ├─╼ 2 + │   ├─╼ 6 + │   │   ├─╼ 14 + │   │   └─╼ 13 + │   └─╼ 5 + │   ├─╼ 12 + │   └─╼ 11 + └─╼ 1 + ├─╼ 4 + │   ├─╼ 10 + │   └─╼ 9 + └─╼ 3 + ├─╼ 8 + └─╼ 7 + + >>> graph = nx.balanced_tree(r=1, h=2, create_using=nx.Graph) + >>> print(nx.forest_str(graph)) + ╟── 1 + ╎   ├── 2 + ╎   └── 0 + """ + import networkx as nx + + printbuf = [] + if write is None: + _write = printbuf.append + else: + _write = write + + if len(graph.nodes) == 0: + _write("╙") + else: + if not nx.is_forest(graph): + raise nx.NetworkXNotImplemented("input must be a forest or the empty graph") + + is_directed = graph.is_directed() + succ = graph.succ if is_directed else graph.adj + + if sources is None: + if is_directed: + # use real source nodes for directed trees + sources = [n for n in graph.nodes if graph.in_degree[n] == 0] + else: + # use arbitrary sources for undirected trees + sources = sorted(graph.nodes, key=lambda n: graph.degree[n]) + + seen = set() + stack = [] + for idx, node in enumerate(sources): + islast_next = idx == 0 + stack.append((node, "", islast_next)) + + while stack: + node, indent, islast = stack.pop() + if node in seen: + continue + seen.add(node) + + # Notes on available box and arrow characters + # https://en.wikipedia.org/wiki/Box-drawing_character + # https://stackoverflow.com/questions/2701192/triangle-arrow + if not indent: + # Top level items (i.e. trees in the forest) get different + # glyphs to indicate they are not actually connected + if islast: + this_prefix = indent + "╙── " + next_prefix = indent + " " + else: + this_prefix = indent + "╟── " + next_prefix = indent + "╎   " + + else: + # For individual forests distinguish between directed and + # undirected cases + if is_directed: + if islast: + this_prefix = indent + "└─╼ " + next_prefix = indent + " " + else: + this_prefix = indent + "├─╼ " + next_prefix = indent + "│   " + else: + if islast: + this_prefix = indent + "└── " + next_prefix = indent + " " + else: + this_prefix = indent + "├── " + next_prefix = indent + "│   " + + if use_labels: + label = graph.nodes[node].get("label", node) + else: + label = node + + _write(this_prefix + str(label)) + + children = [child for child in succ[node] if child not in seen] + for idx, child in enumerate(children, start=1): + islast_next = idx <= 1 + try_frame = (child, next_prefix, islast_next) + stack.append(try_frame) + + if write is None: + # Only return a string if the custom write function was not specified + return "\n".join(printbuf) + + +if __name__ == '__main__': + """ + CommandLine: + python -m netharn.initializers._nx_ext.tree_embedding all + python -m netharn.initializers._nx_ext all + """ + import xdoctest + xdoctest.doctest_module(__file__) diff --git a/netharn/initializers/_nx_extensions.py b/netharn/initializers/_nx_extensions.py deleted file mode 100644 index e364ef8a1ebe92ed05326037f9d796df9344947f..0000000000000000000000000000000000000000 --- a/netharn/initializers/_nx_extensions.py +++ /dev/null @@ -1,1004 +0,0 @@ -""" -EXPERIMENTAL : NEW WORK ON THIS IS HAPPENING IN NETWORKX ITSELF - -ONCE THAT IS DONE I WILL MODIFY THE ALGORITHMS HERE. 
-""" - -import operator -import ubelt as ub -import networkx as nx - -try: - import xdev - profile = xdev.profile -except Exception: - profile = ub.identity - - -# Cython gives a 40x speed boost in the nx version but not here -TRY_USE_CYTHON = 0 - - -@profile -def maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity='auto'): - """ - Finds the maximum common subtree-embedding between two ordered trees. - - A tree S is an embedded subtree of T if it can be obtained from T by a - series of edge contractions. - - Note this produces a subtree embedding, which is not necessarilly a - subgraph isomorphism (although a subgraph isomorphism is also an - embedding.) - - The maximum common embedded subtree problem can be solved in in - `O(n1 * n2 * min(d1, l1) * min(d2, l2))` time on ordered trees with n1 and - n2 nodes, of depth d1 and d2 and with l1 and l2 leaves, respectively - - Implements algorithm described in [1]_. - - References: - On the Maximum Common Embedded Subtree Problem for Ordered Trees - https://pdfs.semanticscholar.org/0b6e/061af02353f7d9b887f9a378be70be64d165.pdf - - http://algo.inria.fr/flajolet/Publications/FlSiSt90.pdf - - Notes: - Exact algorithms for computing the tree edit distance between unordered trees - https://pdf.sciencedirectassets.com/271538/1-s2.0-S0304397510X00299/1-s2.0-S0304397510005463/main.pdf ? - - Tree Edit Distance and Common Subtrees - https://upcommons.upc.edu/bitstream/handle/2117/97554/R02-20.pdf - - A Survey on Tree Edit Distance and Related Problems - https://grfia.dlsi.ua.es/ml/algorithms/references/editsurvey_bille.pdf - - Args: - - tree1 (nx.OrderedDiGraph): first ordered tree - tree2 (nx.OrderedDiGraph): second ordered tree - node_affinity (callable): function - - Example: - >>> from netharn.initializers._nx_extensions import * # NOQA - >>> from netharn.initializers._nx_extensions import _lcs, _print_forest - >>> def random_ordered_tree(n, seed=None): - >>> tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) - >>> otree = nx.OrderedDiGraph() - >>> otree.add_edges_from(tree.edges) - >>> return otree - >>> tree1 = random_ordered_tree(10, seed=1) - >>> tree2 = random_ordered_tree(10, seed=2) - >>> print('tree1') - >>> _print_forest(tree1) - >>> print('tree2') - >>> _print_forest(tree2) - - >>> embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree1, tree2 ) - >>> print('embedding1') - >>> _print_forest(embedding1) - >>> print('embedding2') - >>> _print_forest(embedding2) - """ - if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree1)): - raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') - if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree2)): - raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') - - # Convert the trees to balanced sequences - sequence1, open_to_close, toks = tree_to_balanced_sequence(tree1, open_to_close=None, toks=None) - sequence2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks) - seq1 = sequence1 - seq2 = sequence2 - - open_to_tok = ub.invert_dict(toks) - - # Solve the longest common balanced sequence problem - best, value = longest_common_balanced_sequence( - seq1, seq2, open_to_close, open_to_tok=open_to_tok, node_affinity=node_affinity) - subseq1, subseq2 = best - - # Convert the subsequence back into a tree - embedding1 = seq_to_tree(subseq1, open_to_close, toks) - embedding2 = seq_to_tree(subseq2, open_to_close, toks) - return embedding1, embedding2 - - -@profile -def 
maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity='auto'): - """ - Isomorphic version of `maximum_common_ordered_tree_embedding`. - - CommandLine: - xdoctest -m /home/joncrall/code/netharn/netharn/initializers/_nx_extensions.py maximum_common_ordered_subtree_isomorphism:1 --profile && cat profile_output.txt - - Ignore: - >>> from netharn.initializers._nx_extensions import * # NOQA - >>> from netharn.initializers._nx_extensions import _lcs, _print_forest - >>> def random_ordered_tree(n, seed=None): - >>> tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) - >>> otree = nx.OrderedDiGraph() - >>> otree.add_edges_from(tree.edges) - >>> return otree - >>> tree1 = random_ordered_tree(10, seed=3) - >>> tree2 = random_ordered_tree(10, seed=2) - >>> tree1.add_edges_from(tree2.edges, weight=1) - >>> tree1 = nx.minimum_spanning_arborescence(tree1) - >>> tree2.add_edges_from(tree1.edges, weight=1) - >>> tree2 = nx.minimum_spanning_arborescence(tree2) - >>> tree1.remove_edge(4, 7) - >>> tree1.remove_edge(4, 9) - >>> tree1.add_edge(4, 10) - >>> tree1.add_edge(10, 7) - >>> tree1.add_edge(10, 9) - >>> #tree1.add_edges_from([(9, 11), (11, 12), (12, 13), (13, 14)]) - >>> #tree2.add_edges_from([(9, 11), (11, 12), (12, 13), (13, 14)]) - >>> tree1.add_edges_from([(9, 11), (11, 12)]) - >>> tree2.add_edges_from([(9, 11), (11, 12)]) - >>> tree2.add_edge(100, 0) - >>> tree1.add_edge(102, 100) - >>> tree1.add_edge(100, 101) - >>> tree1.add_edge(101, 0) - >>> tree1.add_edge(5, 201) - >>> tree1.add_edge(5, 202) - >>> tree1.add_edge(5, 203) - >>> tree1.add_edge(201, 2000) - >>> tree1.add_edge(2000, 2001) - >>> tree1.add_edge(2001, 2002) - >>> tree1.add_edge(2002, 2003) - >>> tree2.add_edge(5, 202) - >>> tree2.add_edge(5, 203) - >>> tree2.add_edge(5, 201) - >>> tree2.add_edge(201, 2000) - >>> tree2.add_edge(2000, 2001) - >>> tree2.add_edge(2001, 2002) - >>> tree2.add_edge(2002, 2003) - >>> print('-----') - >>> print('tree1') - >>> _print_forest(tree1) - >>> print('tree2') - >>> _print_forest(tree2) - >>> subtree1, subtree2 = maximum_common_ordered_subtree_isomorphism(tree1, tree2 ) - >>> print('-----') - >>> print('subtree1') - >>> _print_forest(subtree1) - >>> print('subtree2') - >>> _print_forest(subtree2) - >>> embedding1, embedding2 = maximum_common_ordered_tree_embedding(tree1, tree2) - >>> print('-----') - >>> print('embedding1') - >>> _print_forest(embedding1) - >>> print('embedding2') - >>> _print_forest(embedding2) - >>> if 0: - >>> ti = timerit.Timerit(6, bestof=2, verbose=2) - >>> for timer in ti.reset('isomorphism'): - >>> with timer: - >>> maximum_common_ordered_subtree_isomorphism(tree1, tree2 ) - >>> for timer in ti.reset('embedding'): - >>> with timer: - >>> maximum_common_ordered_tree_embedding(tree1, tree2 ) - >>> from networkx import isomorphism - >>> assert isomorphism.DiGraphMatcher(tree1, subtree1).subgraph_is_isomorphic() - >>> assert isomorphism.DiGraphMatcher(tree2, subtree2).subgraph_is_isomorphic() - >>> list(isomorphism.DiGraphMatcher(tree1, tree2).subgraph_isomorphisms_iter()) - >>> list(isomorphism.DiGraphMatcher(tree1, tree2).subgraph_monomorphisms_iter()) - >>> list(isomorphism.DiGraphMatcher(subtree1, subtree2).subgraph_isomorphisms_iter()) - >>> list(isomorphism.DiGraphMatcher(tree1, subtree1).subgraph_isomorphisms_iter()) - >>> list(isomorphism.DiGraphMatcher(tree2, subtree2).subgraph_isomorphisms_iter()) - - Ignore: - >>> from netharn.initializers._nx_extensions import * # NOQA - >>> from netharn.initializers._nx_extensions import _lcs, _print_forest - >>> def 
random_ordered_tree(n, seed=None): - >>> if n > 0: - >>> tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) - >>> otree = nx.OrderedDiGraph() - >>> if n > 0: - >>> otree.add_edges_from(tree.edges) - >>> return otree - >>> import random - >>> rng = random.Random(90269698983701724775426457020022) - >>> num = 1000 - >>> def _gen_seeds(num): - >>> for _ in range(num): - >>> yield (rng.randint(0, 50), rng.randint(0, 50), rng.randint(0, 2 ** 64), rng.randint(0, 2 ** 64)) - >>> for n1, n2, s1, s2 in ub.ProgIter(_gen_seeds(num=num), total=num, verbose=3): - >>> tree1 = random_ordered_tree(n1, seed=s1) - >>> tree2 = random_ordered_tree(n2, seed=s2) - >>> #print('-----') - >>> #print('tree1') - >>> #_print_forest(tree1) - >>> #print('tree2') - >>> #_print_forest(tree2) - >>> subtree1, subtree2 = maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity='auto') - >>> #print('-----') - >>> #print('subtree1') - >>> #_print_forest(subtree1) - >>> #print('subtree2') - >>> #_print_forest(subtree2) - >>> from networkx import isomorphism - >>> assert isomorphism.DiGraphMatcher(tree1, subtree1).subgraph_is_isomorphic() - >>> assert isomorphism.DiGraphMatcher(tree2, subtree2).subgraph_is_isomorphic() - - """ - try: - if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree1)): - raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') - if not (isinstance(tree1, nx.OrderedDiGraph) and nx.is_forest(tree2)): - raise nx.NetworkXNotImplemented('only implemented for directed ordered trees') - except nx.NetworkXPointlessConcept: - subtree1 = nx.OrderedDiGraph() - subtree2 = nx.OrderedDiGraph() - return subtree1, subtree2 - - # Convert the trees to balanced sequences - sequence1, open_to_close, toks = tree_to_balanced_sequence(tree1, open_to_close=None, toks=None, mode='chr') - sequence2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks, mode='chr') - seq1 = sequence1 - seq2 = sequence2 - - open_to_tok = ub.invert_dict(toks) - - # Solve the longest common balanced sequence problem - best, value = longest_common_isomorphic_sequence( - seq1, seq2, open_to_close, open_to_tok=open_to_tok, node_affinity=node_affinity) - subseq1, subseq2 = best - - # Convert the subsequence back into a tree - subtree1 = seq_to_tree(subseq1, open_to_close, toks) - subtree2 = seq_to_tree(subseq2, open_to_close, toks) - return subtree1, subtree2 - - -class UnbalancedException(Exception): - pass - - -def tree_to_balanced_sequence(tree, open_to_close=None, toks=None, mode='tuple'): - from collections import namedtuple - Token = namedtuple('Token', ['action', 'value']) - # mapping between opening and closing tokens - sources = [n for n in tree.nodes if tree.in_degree[n] == 0] - sequence = [] - - if open_to_close is None: - open_to_close = {} - if toks is None: - toks = {} - - for source in sources: - for u, v, etype in nx.dfs_labeled_edges(tree, source=source): - if etype == 'forward': - # u has been visited by v has not - if v not in toks: - if mode == 'tuple': - # TODO: token encoding scheme where subdirectories - # are matchable via a custom operation. 
- # open_tok = '<{}>'.format(v) - # close_tok = ''.format(v) - open_tok = Token('open', v) - close_tok = Token('close', v) - elif mode == 'number': - open_tok = len(toks) + 1 - close_tok = -open_tok - elif mode == 'paren': - open_tok = '{}('.format(v) - close_tok = '){}'.format(v) - elif mode == 'chr': - open_tok = str(v) - close_tok = str(v) + u'\u0301' - # chr(ord(v) + 128) - toks[v] = open_tok - open_to_close[open_tok] = close_tok - open_tok = toks[v] - sequence.append(open_tok) - elif etype == 'reverse': - # Both u and v are visited and the edge is in the tree - close_tok = open_to_close[toks[v]] - sequence.append(close_tok) - else: - raise KeyError(etype) - sequence = tuple(sequence) - return sequence, open_to_close, toks - - -def seq_to_tree(subseq, open_to_close, toks): - open_to_tok = ub.invert_dict(toks) - subtree = nx.OrderedDiGraph() - stack = [] - for token in subseq: - if token in open_to_close: - node = open_to_tok[token] - if stack: - parent = open_to_tok[stack[-1]] - subtree.add_edge(parent, node) - else: - subtree.add_node(node) - stack.append(token) - else: - if not stack: - raise Exception - prev_open = stack.pop() - want_close = open_to_close[prev_open] - if token != want_close: - raise Exception - return subtree - - -def random_ordered_tree(n, seed=None): - tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) - otree = nx.OrderedDiGraph() - otree.add_edges_from(tree.edges) - return otree - - -@profile -def generate_balance_unsafe_python(sequence, open_to_close): - """ - Benchmark: - >>> tree = random_ordered_tree(1000) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='tuple') - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='number') - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2) - >>> for timer in ti.reset('time'): - >>> with timer: - >>> list(generate_balance_unsafe(sequence, open_to_close)) - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2) - >>> for timer in ti.reset('time'): - >>> with timer: - >>> list(generate_balance_unsafe_cython(sequence, open_to_close)) - """ - stacklen = 0 - for token in sequence: - if token in open_to_close: - stacklen += 1 - else: - stacklen -= 1 - yield stacklen == 0, token - - -@profile -def balanced_decomp(sequence, open_to_close): - """ - Note this is not exactly the same as the decomposition in the paper. - That is because we also return the "wrapping" element, and we let the - user do the head + tail concatenation. 
- - Example: - >>> open_to_close = {0: 1} - >>> sequence = [0, 0, 0, 1, 1, 1, 0, 1] - >>> open_to_close = {'{': '}', '(': ')', '[': ']'} - >>> sequence = '({[[]]})[[][]]' - >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) - >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) - """ - gen = generate_balance(sequence, open_to_close) - - bal_curr, tok_curr = next(gen) - pop_open = sequence[0:1] - want_close = open_to_close[tok_curr] - - head_stop = 1 - for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): - if tok_curr is None: - break - elif bal_curr and tok_curr == want_close: - pop_close = sequence[head_stop:head_stop + 1] - break - head = sequence[1:head_stop] - # if __debug__: - # list(gen) # exhaust the generator to check we are balanced - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - - -@profile -def balanced_decomp_unsafe(sequence, open_to_close): - """ - open_to_close = {0: 1} - sequence = [0, 0, 0, 1, 1, 1, 0, 1] - open_to_close = {'{': '}', '(': ')', '[': ']'} - sequence = '({[[]]})[[][]]' - a1, b1, head, tail = balanced_decomp(sequence, open_to_close) - a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) - - Benchmark: - >>> from netharn.initializers._nx_extensions import * # NOQA - >>> tree = random_ordered_tree(100) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2, unit='us') - >>> for timer in ti.reset('safe-python'): - >>> with timer: - >>> list(balanced_decomp(sequence, open_to_close)) - >>> for timer in ti.reset('unsafe-python'): - >>> with timer: - >>> list(balanced_decomp_unsafe(sequence, open_to_close)) - >>> for timer in ti.reset('unsafe-python-v2'): - >>> with timer: - >>> list(balanced_decomp_unsafe2_python(sequence, open_to_close)) - >>> for timer in ti.reset('unsafe-c/python-v2'): - >>> with timer: - >>> list(balanced_decomp_unsafe2(sequence, open_to_close)) - """ - gen = generate_balance_unsafe(sequence, open_to_close) - - bal_curr, tok_curr = next(gen) - pop_open = sequence[0:1] - want_close = open_to_close[tok_curr] - - head_stop = 1 - for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): - if bal_curr and tok_curr == want_close: - pop_close = sequence[head_stop:head_stop + 1] - break - head = sequence[1:head_stop] - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - - -@profile -def balanced_decomp_unsafe2_python(sequence, open_to_close): - stacklen = 0 - seq_iter = iter(sequence) - tok_curr = next(seq_iter) - stacklen += 1 if tok_curr in open_to_close else -1 - want_close = open_to_close[tok_curr] - - head_stop = 1 - for head_stop, tok_curr in enumerate(seq_iter, start=1): - stacklen += 1 if tok_curr in open_to_close else -1 - if stacklen == 0 and tok_curr == want_close: - break - - pop_close = sequence[head_stop:head_stop + 1] - pop_open = sequence[0:1] - head = sequence[1:head_stop] - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - - -generate_balance_unsafe = generate_balance_unsafe_python -balanced_decomp_unsafe2 = balanced_decomp_unsafe2_python - - -if TRY_USE_CYTHON: - try: - from netharn.initializers import _nx_extensions_cython_backend as cyb - - generate_balance_unsafe_cython = cyb.generate_balance_unsafe_cython - generate_balance_unsafe = cyb.generate_balance_unsafe_cython - - balanced_decomp_unsafe2_cython = cyb.balanced_decomp_unsafe2_cython - balanced_decomp_unsafe2 = cyb.balanced_decomp_unsafe2_cython - except Exception: 
- pass - - -def generate_balance(sequence, open_to_close, safe=True): - """ - Args: - safe (bool): if True we will error if the sequence is not balanced - if you are SURE the sequence is balanced set safe=False to slightly - improve runtime. - - - CommandLine: - xdoctest -m /home/joncrall/code/netharn/netharn/initializers/_nx_extensions.py generate_balance:1 --profile - - Example: - >>> open_to_close = {0: 1} - >>> sequence = [0, 0, 0, 1, 1, 1] - >>> gen = list(generate_balance(sequence, open_to_close)) - >>> for flag, token in gen: - >>> print('flag={:d}, token={}'.format(flag, token)) - - Example: - >>> tree = random_ordered_tree(1000) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) - >>> gen = list(generate_balance(sequence, open_to_close)) - >>> for flag, token in gen: - >>> print('flag={:d}, token={}'.format(flag, token)) - - Benchmark: - >>> from netharn.initializers._nx_extensions import * # NOQA - >>> tree = random_ordered_tree(100) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2, unit='us') - >>> for timer in ti.reset('safe-python'): - >>> with timer: - >>> list(generate_balance(sequence, open_to_close)) - >>> for timer in ti.reset('unsafe-python'): - >>> with timer: - >>> list(generate_balance_unsafe(sequence, open_to_close)) - - Ignore: - from netharn.initializers._nx_extensions import * # NOQA - from numba import jit - jit_generate_balance = jit(forceobj=True)(generate_balance) - - open_to_close = {0: 1} - sequence = [0, 0, 0, 1, 1, 1] - list(jit_generate_balance(sequence, open_to_close)) - - tree = random_ordered_tree(1000) - sequence, open_to_close, toks = tree_to_balanced_sequence(tree) - - import timerit - ti = timerit.Timerit(100, bestof=10, verbose=2, unit='us') - - for timer in ti.reset('safe-python'): - with timer: - list(generate_balance(sequence, open_to_close)) - - for timer in ti.reset('unsafe-python'): - with timer: - list(generate_balance_unsafe(sequence, open_to_close)) - - for timer in ti.reset('numba'): - with timer: - list(jit_generate_balance(sequence, open_to_close)) - """ - if safe: - stack = [] - # Traversing the Expression - for token in sequence: - - if token in open_to_close: - # Push opening elements onto the stack - stack.append(token) - else: - # Check that closing elements - if not stack: - raise UnbalancedException - prev_open = stack.pop() - want_close = open_to_close[prev_open] - - if token != want_close: - raise UnbalancedException - - # If the stack is empty the sequence is currently balanced - currently_balanced = not bool(stack) - yield currently_balanced, token - - if stack: - raise UnbalancedException - else: - yield from generate_balance_unsafe(sequence, open_to_close) - - -@profile -def longest_common_balanced_sequence(seq1, seq2, open_to_close, node_affinity='auto', open_to_tok=None): - """ - CommandLine: - xdoctest -m /home/joncrall/code/netharn/netharn/initializers/_nx_extensions.py longest_common_balanced_sequence:0 --profile && cat profile_output.txt - - Example: - >>> tree1 = random_ordered_tree(100, seed=1) - >>> tree2 = random_ordered_tree(100, seed=2) - >>> seq1, open_to_close, toks = tree_to_balanced_sequence(tree1) - >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks) - >>> longest_common_balanced_sequence(seq1, seq2, open_to_close) - - Benchmark: - >>> tree1 = random_ordered_tree(20, seed=1) - >>> tree2 = random_ordered_tree(20, seed=2) - >>> seq1, open_to_close, toks = 
tree_to_balanced_sequence(tree1) - >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks) - >>> longest_common_balanced_sequence(seq1, seq2, open_to_close) - - import sys, ubelt - sys.path.append(ubelt.expandpath('~/code/netharn')) - from netharn.initializers._nx_extensions import * # NOQA - from netharn.initializers._nx_extensions import _best_prefix_transform, _lcs, _print_forest - - open_to_close = {'0': '1'} - seq1 = '0010010010111100001011011011' - seq2 = '001000101101110001000100101110111011' - - open_to_close = {'(': ')'} - seq1 = '(()(()(()())))(((()())())())' - seq2 = '(()((()())()))((()((()(()()))()))())' - longest_common_balanced_sequence(seq1, seq2, open_to_close) - - open_to_close = {'0': '1'} - seq1 = '0010010010111100001011011011' - seq2 = '001000101101110001000100101110111011' - longest_common_balanced_sequence(seq1, seq2, open_to_close) - - open_to_close = {'0': '1'} - seq1 = '001101' - seq2 = '00110011' - seq1 = '001101' - seq2 = '00110011' - longest_common_balanced_sequence(seq1, seq2, open_to_close) - - open_to_close = {'{': '}', '(': ')', '[': ']'} - seq1 = '(({}{([])}[{}]))' - seq2 = '((({}[{{}}])))' - - seq1 = '({[[[]]]}){}' - seq2 = '{}{[[[]]]}' - best, value = longest_common_balanced_sequence(seq1, seq2, open_to_close) - subseq1, subseq2 = best - print('subseq1 = {!r}'.format(subseq1)) - """ - if node_affinity == 'auto': - node_affinity = operator.eq - if node_affinity is None: - def _matchany(a, b): - return True - node_affinity = _matchany - _memo = {} - _seq_memo = {} - if open_to_tok is None: - class Dummy: - def __getitem__(self, key): - return key - open_to_tok = Dummy() - best, value = _lcs(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - return best, value - - -@profile -def _lcs(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): - if not seq1: - return (seq1, seq1), 0 - elif not seq2: - return (seq2, seq2), 0 - else: - # if len(seq2) < len(seq1): - # seq1, seq2 = seq2, seq1 - # key = (seq1, seq2) - key1 = hash(seq1) # using hash(seq) is faster than seq itself - key2 = hash(seq2) - key = hash((key1, key2)) - if key in _memo: - return _memo[key] - - # TODO: we can probably just do a single linear run through the - # sequences to index the sub-sequence locations and then apply an - # offset when we run the decomposed sequence. 
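The recursion just below splits each sequence at its first matched bracket pair into ``(a, b, head, tail)``. A minimal standalone sketch of that split, assuming single-character bracket tokens (``simple_decomp`` is a hypothetical, simplified stand-in for ``balanced_decomp_unsafe2``, not part of the patch):

# Simplified stand-in for balanced_decomp_unsafe2: split a balanced bracket
# string into its first token pair plus the "head" (nested part) and "tail"
# (right siblings). Assumes single-character tokens.
def simple_decomp(seq, open_to_close):
    want_close = open_to_close[seq[0]]
    depth = 0
    for idx, tok in enumerate(seq):
        depth += 1 if tok in open_to_close else -1
        if depth == 0 and tok == want_close:
            break
    pop_open, pop_close = seq[0], seq[idx]
    head, tail = seq[1:idx], seq[idx + 1:]
    return pop_open, pop_close, head, tail

open_to_close = {'(': ')', '[': ']'}
assert simple_decomp('([[]])[]', open_to_close) == ('(', ')', '[[]]', '[]')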
- if key1 in _seq_memo: - a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] - else: - a1, b1, head1, tail1 = balanced_decomp_unsafe2(seq1, open_to_close) - head1_tail1 = head1 + tail1 - _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 - - if key2 in _seq_memo: - a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] - else: - a2, b2, head2, tail2 = balanced_decomp_unsafe2(seq2, open_to_close) - head2_tail2 = head2 + tail2 - _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 - - # Case 2: The current edge in sequence1 is deleted - best, val = _lcs(head1_tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - # Case 3: The current edge in sequence2 is deleted - cand, val_alt = _lcs(seq1, head2_tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - if val_alt > val: - best = cand - val = val_alt - - # Case 1: The LCS involves this edge - t1 = open_to_tok[a1[0]] - t2 = open_to_tok[a2[0]] - # if node_affinity(a1[0], a2[0]): - affinity = node_affinity(t1, t2) - if affinity: - new_heads, pval_h = _lcs(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - new_tails, pval_t = _lcs(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - cand = (subseq1, subseq2) - val_alt = pval_h + pval_t + affinity - if val_alt > val: - best = cand - val = val_alt - - found = (best, val) - _memo[key] = found - return found - - -@profile -def longest_common_isomorphic_sequence(seq1, seq2, open_to_close, node_affinity='auto', open_to_tok=None): - if node_affinity == 'auto': - node_affinity = operator.eq - if node_affinity is None: - def _matchany(a, b): - return True - node_affinity = _matchany - _memo = {} - _seq_memo = {} - if open_to_tok is None: - class Dummy: - def __getitem__(self, key): - return key - open_to_tok = Dummy() - best_lvl, value_lvl, best_low, value_low = _lcsi(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - if value_lvl > value_low: - best = best_lvl - value = value_lvl - else: - best = best_low - value = value_low - - return best, value - - -@profile -def _lcsi(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): - """ - Prototype isomorphic only version - """ - if not seq1: - return (seq1, seq1), 0, (seq1, seq1), 0 - elif not seq2: - return (seq2, seq2), 0, (seq2, seq2), 0 - else: - key1 = hash(seq1) - key2 = hash(seq2) - key = hash((key1, key2)) - if key in _memo: - return _memo[key] - - if key1 in _seq_memo: - a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] - else: - a1, b1, head1, tail1 = balanced_decomp_unsafe2(seq1, open_to_close) - head1_tail1 = head1 + tail1 - _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 - - if key2 in _seq_memo: - a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] - else: - a2, b2, head2, tail2 = balanced_decomp_unsafe2(seq2, open_to_close) - head2_tail2 = head2 + tail2 - _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 - - # TODO: IS THIS THE CORRECT MODIFICATION TO THE RECURRANCE TO - # ACHIEVE A SUBTREE ISOMORPHISM INSTEAD OF AN EMBEDDING? 
- r""" - - tree1 = nx.OrderedDiGraph() - tree1.add_nodes_from(['a', 'b', 'c', 'd', 'e', 'f', 'g']) - tree1.add_edges_from([('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'e'), ('b', 'f'), ('c', 'g')]) - - _print_forest(tree1) - - └── a - ├── b - │   ├── e - │   └── f - ├── c - │   └── g - └── d - - seq1, open_to_close, toks = tree_to_balanced_sequence(tree1, mode='chr') - a, b, head1, tail1 = balanced_decomp(seq1, open_to_close) - _print_forest(seq_to_tree(head1, open_to_close, toks)) - _print_forest(seq_to_tree(tail1, open_to_close, toks)) - - CONTRACTED NODE: - a - - HEAD (children of the contracted node) - - ├── b - │   ├── e - │   └── f - ├── c - │   └── g - └── d - - TAIL (right siblings of the contracted node) - -- - - a, b, head11, tail11 = balanced_decomp(head1, open_to_close) - _print_forest(seq_to_tree(head11, open_to_close, toks)) - _print_forest(seq_to_tree(tail11, open_to_close, toks)) - - CONTRACTED NODE: - b - - HEAD OF HEAD - ├── e - └── f - - TAIL OF HEAD - ├── c - │   └── g - └── d - - - The problem here is that if you are at a level where two levels down - there are two matches, you will return those two matches as the best - solution at that layer, and therefore you won't flag if there is a - feasible solution at this layer. This is a problem because that - feasible low-value solution might be part of the highest value - solution. - - Perhaps we return two solutions at each step: the solution value at - this level if one exists, and the solution value at any other depth. - We are allowed to add to the first, but can take the second if we want - to. - - This should work because we know a solution that skipped a layer will - never be added to, and we are always keeping track of the solution that - might change. By the time we get to the root level, we have enough info - to know which is better. - """ - - # If any of these cases are selected we are not choosing the leftmost - # node as our match - best_lvl, val_lvl, best_low, val_low = None, -1, None, -1 - - # TODO: it may be the case that some of these tests are redundant, in - # which case we could simplify and speed up the algorithm. We would - # need to prove that the value in one of these tests was always lower - # than the value in another one of these tests, in that case we could - # remove the former. 
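The two-track bookkeeping argued for above (keep one candidate that callers may still extend, plus the best already-finalized candidate) follows the same pattern as a classic longest-run scan. The sketch below is only a loose analogue; the names ``best_lvl`` and ``best_low`` are borrowed from this function purely for illustration and are not part of the original code.

# Loose analogue of the (lvl, low) bookkeeping: track both the candidate
# that can still be extended (best_lvl) and the best value finalized so
# far (best_low); only at the very end do we know which one wins.
def longest_run(s):
    best_lvl = 0   # length of the run ending at the current character
    best_low = 0   # best run seen anywhere so far
    for idx, ch in enumerate(s):
        best_lvl = best_lvl + 1 if idx and ch == s[idx - 1] else 1
        best_low = max(best_low, best_lvl)
    return best_low

assert longest_run('aabbbba') == 4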
- - # When using the head part of the decomp, we can only update the "low" candidate - cand_lvl, score_lvl, cand_low, score_low = _lcsi(head1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - if score_low > val_low: - val_low = score_low - best_low = cand_low - if score_lvl > val_low: - val_low = score_lvl - best_low = cand_lvl - - cand_lvl, score_lvl, cand_low, score_low = _lcsi(seq1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - if score_low > val_low: - val_low = score_low - best_low = cand_low - if score_lvl > val_low: - val_low = score_lvl - best_low = cand_lvl - - # As long as we are only using the tail part of the decomp we can update - # both the lvl and low scores - cand_lvl, score_lvl, cand_low, score_low = _lcsi(tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - if score_lvl > val_lvl: - val_lvl = score_lvl - best_lvl = cand_lvl - if score_low > val_low: - val_low = score_low - best_low = cand_low - - cand_lvl, score_lvl, cand_low, score_low = _lcsi(seq1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - if score_lvl > val_lvl: - val_lvl = score_lvl - best_lvl = cand_lvl - if score_low > val_low: - val_low = score_low - best_low = cand_low - - # This is the case where we found a matching node - t1 = open_to_tok[a1[0]] - t2 = open_to_tok[a2[0]] - affinity = node_affinity(t1, t2) - if affinity: - - new_heads_lvl, pval_h_lvl, new_heads_low, pval_h_low = _lcsi(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - new_tails_lvl, pval_t_lvl, new_tails_low, pval_t_low = _lcsi(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - # Add to the best solution at the former level - score_lvl = pval_h_lvl + pval_t_lvl + affinity - if score_lvl > val_lvl: - new_head1, new_head2 = new_heads_lvl - new_tail1, new_tail2 = new_tails_lvl - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - cand_lvl = (subseq1, subseq2) - val_lvl = score_lvl - best_lvl = cand_lvl - - # In my big tests these were never hit once, is it true that this - # test was covered by a previous case? - cand_low = new_heads_low - score_low = pval_h_low - if score_low > val_low: - val_low = score_low - best_low = cand_low - - cand_low = new_tails_low - score_low = pval_t_low - if score_low > val_low: - val_low = score_low - best_low = cand_low - - # We return two solutions: - # the best AT this level (lvl), and the best AT any lowers (low). 
- found = (best_lvl, val_lvl, best_low, val_low) - _memo[key] = found - return found - - -def _print_forest(graph): - """ - Nice ascii representation of a forest - - Ignore: - graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) - _print_forest(graph) - - graph = CategoryTree.demo('coco').graph - _print_forest(graph) - """ - if len(graph.nodes) == 0: - print('--') - return - assert nx.is_forest(graph) - - def _recurse(node, indent='', islast=False): - if islast: - this_prefix = indent + '└── ' - next_prefix = indent + ' ' - else: - this_prefix = indent + '├── ' - next_prefix = indent + '│   ' - label = graph.nodes[node].get('label', node) - print(this_prefix + str(label)) - graph.succ[node] - children = graph.succ[node] - for idx, child in enumerate(children, start=1): - islast_next = (idx == len(children)) - _recurse(child, indent=next_prefix, islast=islast_next) - - sources = [n for n in graph.nodes if graph.in_degree[n] == 0] - for idx, node in enumerate(sources, start=1): - islast_next = (idx == len(sources)) - _recurse(node, indent='', islast=islast_next) - - -def maximum_common_ordered_paths(paths1, paths2, sep='/'): - import networkx as nx - - # the longest common balanced sequence problem - def _affinity(tok1, tok2): - score = 0 - for t1, t2 in zip(tok1[::-1], tok2[::-1]): - if t1 == t2: - score += 1 - else: - break - return score - # return tok1[-1] == tok2[-1] - node_affinity = _affinity - # import operator - # eq = operator.eq - - def paths_to_tree(paths): - tree = nx.OrderedDiGraph() - for path in sorted(paths): - parts = tuple(path.split(sep)) - node_path = [] - for i in range(1, len(parts) + 1): - node = parts[0:i] - tree.add_node(node) - tree.nodes[node]['label'] = node[-1] - node_path.append(node) - for u, v in ub.iter_window(node_path, 2): - tree.add_edge(u, v) - return tree - - tree1 = paths_to_tree(paths1) - tree2 = paths_to_tree(paths2) - - subtree1, subtree2 = maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity=node_affinity) - # subtree1, subtree2 = maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity=node_affinity) - - subpaths1 = [sep.join(node) for node in subtree1.nodes if subtree1.out_degree[node] == 0] - subpaths2 = [sep.join(node) for node in subtree2.nodes if subtree2.out_degree[node] == 0] - return subpaths1, subpaths2 diff --git a/netharn/initializers/_nx_extensions_cython_backend.pyx b/netharn/initializers/_nx_extensions_cython_backend.pyx deleted file mode 100644 index c3d312e6146f61a51417bfd838ae8235a066ba61..0000000000000000000000000000000000000000 --- a/netharn/initializers/_nx_extensions_cython_backend.pyx +++ /dev/null @@ -1,46 +0,0 @@ -""" -cythonize -a -i ~/code/netharn/netharn/initializers/_nx_extensions_cython_backend.pyx - - >>> from netharn.initializers import _nx_extensions_cython_backend - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2) - >>> for timer in ti.reset('time'): - >>> with timer: - >>> list(_nx_extensions_cython_backend.generate_balance_unsafe_cython(sequence, open_to_close)) - -""" - -def generate_balance_unsafe_cython(sequence, open_to_close): - cdef tuple item - cdef bint flag - cdef int stacklen = 0 - for token in sequence: - if token in open_to_close: - stacklen += 1 - else: - stacklen -= 1 - flag = stacklen == 0 - item = (flag, token) - yield item - - -def balanced_decomp_unsafe2_cython(tuple sequence, dict open_to_close): - cdef int stacklen = 1 # always +1 in the first iteration - cdef int head_stop = 1 - - tok_curr = sequence[0] - want_close = 
open_to_close[tok_curr] - - # for tok_curr in sequence[1:]: - for head_stop in range(1, len(sequence)): - tok_curr = sequence[head_stop] - stacklen += 1 if tok_curr in open_to_close else -1 - if stacklen == 0 and tok_curr == want_close: - pop_close = sequence[head_stop:head_stop + 1] - break - - pop_open = sequence[0:1] - head = sequence[1:head_stop] - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - diff --git a/netharn/initializers/balanced_sequence.py b/netharn/initializers/balanced_sequence.py deleted file mode 100644 index 13d5db1751c06bea874db00cac374abbb9849b6a..0000000000000000000000000000000000000000 --- a/netharn/initializers/balanced_sequence.py +++ /dev/null @@ -1,969 +0,0 @@ -import operator -import ubelt as ub -import networkx as nx - -try: - import xdev - profile = xdev.profile -except Exception: - profile = ub.identity - - -# @profile -def longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok=None, node_affinity='auto', impl='iter'): - """ - CommandLine: - xdoctest -m /home/joncrall/code/netharn/netharn/initializers/balanced_sequence.py longest_common_balanced_sequence:0 --profile && cat profile_output.txt - - Example: - >>> from netharn.initializers.balanced_sequence import * # NOQA - >>> from netharn.initializers.balanced_sequence import _lcs_iter_prehash, _lcs_iter_simple, _lcs_recurse, _print_forest - >>> tree1 = random_ordered_tree(5, seed=10, pool='[{(') - >>> tree2 = random_ordered_tree(5, seed=3, pool='[{(') - - >>> import kwarray - >>> rng = kwarray.ensure_rng(3432432, 'python') - >>> tree1 = random_ordered_tree(100, seed=rng, pool='[{(') - >>> tree2 = random_ordered_tree(100, seed=rng, pool='[{(') - >>> if len(tree1.nodes) < 20: - >>> _print_forest(tree1) - >>> _print_forest(tree2) - >>> seq1, open_to_close, toks = tree_to_balanced_sequence(tree1, mode='label', strhack=1) - >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks, mode='label', strhack=1) - >>> full_seq1 = seq1 - >>> full_seq2 = seq2 - >>> print('seq1 = {!r}'.format(seq1)) - >>> print('seq2 = {!r}'.format(seq2)) - >>> open_to_tok = ub.invert_dict(toks) - >>> node_affinity = operator.eq - >>> with ub.Timer('iterative-alt2'): - >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter-alt2') - >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) - >>> with ub.Timer('iterative-alt1'): - >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter-alt1') - >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) - >>> with ub.Timer('iterative'): - >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter') - >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) - >>> with ub.Timer('recursive'): - >>> best2, val2 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='recurse') - >>> print('val2, best2 = {}, {!r}'.format(val2, best2)) - >>> #with ub.Timer('iterative-prehash'): - >>> # best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter-prehash') - >>> # print('val1, best1 = {}, {!r}'.format(val1, best1)) - """ - if node_affinity == 'auto' or node_affinity == 'eq': - node_affinity = operator.eq - if node_affinity is None: - def _matchany(a, b): - return True - node_affinity = _matchany - _memo = {} - _seq_memo = {} - if open_to_tok is None: - class Dummy: - def __getitem__(self, key): - return key - open_to_tok = Dummy() - full_seq1 
= seq1 - full_seq2 = seq2 - if impl == 'recurse': - best, value = _lcs_recurse(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - elif impl == 'iter': - best, value = _lcs_iter_simple(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) - elif impl == 'iter-prehash': - best, value = _lcs_iter_prehash(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) - elif impl == 'iter-alt1': - best, value = _lcs_iter_simple_alt1(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) - elif impl == 'iter-alt2': - best, value = _lcs_iter_simple_alt2(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) - else: - raise KeyError(impl) - return best, value - - -@profile -def _lcs_iter_simple(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): - """ - Converts _lcs_recursive to an iterative algorithm using a fairly - straightforward method that effectivly simulates callstacks - """ - all_decomp1 = generate_all_decompositions(full_seq1, open_to_close, open_to_tok) - all_decomp2 = generate_all_decompositions(full_seq2, open_to_close, open_to_tok) - - args0 = (full_seq1, full_seq2) - frame0 = args0 - stack = [frame0] - - _results = {} - # Populate base cases - empty1 = type(ub.peek(all_decomp1.keys()))() - empty2 = type(ub.peek(all_decomp2.keys()))() - best = (empty1, empty2) - base_result = (0, best) - for seq1 in all_decomp1.keys(): - key1 = seq1 - t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] - _results[(seq1, empty2)] = base_result - _results[(head1, empty2)] = base_result - _results[(tail1, empty2)] = base_result - _results[(head_tail1, empty2)] = base_result - - for seq2 in all_decomp2.keys(): - key2 = seq2 - t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] - _results[(empty1, seq2)] = base_result - _results[(empty1, head2)] = base_result - _results[(empty1, tail2)] = base_result - _results[(empty1, head_tail2)] = base_result - - del args0 - del frame0 - del empty1 - del empty2 - del best - del base_result - - missing_frames = [] - while stack: - key = stack.pop() - if key not in _results: - seq1, seq2 = key - missing_frames.clear() - - # try: - t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] - # except KeyError: - # a1, b1, head1, tail1 = balanced_decomp_unsafe(seq1, open_to_close) - # head_tail1 = head1 + tail1 - # all_decomp1[seq1] = a1, b1, head1, tail1, head_tail1 - - # try: - t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] - # except KeyError: - # a2, b2, head2, tail2 = balanced_decomp_unsafe(seq2, open_to_close) - # head_tail2 = head2 + tail2 - # all_decomp2[seq2] = a2, b2, head2, tail2, head_tail2 - - # Case 2: The current edge in sequence1 is deleted - try: - try_key = (head_tail1, seq2) - cand1 = _results[try_key] - except KeyError: - missing_frames.append(try_key) - - # Case 3: The current edge in sequence2 is deleted - try: - try_key = (seq1, head_tail2) - cand2 = _results[try_key] - except KeyError: - missing_frames.append(try_key) - - # Case 1: The LCS involves this edge - affinity = node_affinity(t1, t2) - if affinity: - try: - try_key = (head1, head2) - pval_h, new_heads = _results[try_key] - except KeyError: - missing_frames.append(try_key) - - try: - try_key = (tail1, tail2) - pval_t, new_tails = _results[try_key] - except KeyError: - missing_frames.append(try_key) - - if not missing_frames: - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - res3 = 
(subseq1, subseq2) - val3 = pval_h + pval_t + affinity - cand3 = (val3, res3) - else: - cand3 = (-1, None) - - if missing_frames: - # We did not solve this frame yet - stack.append(key) - stack.extend(missing_frames) - # stack.extend(missing_frames[::-1]) - else: - # We solved the frame - _results[key] = max(cand1, cand2, cand3) - - val, best = _results[key] - found = (best, val) - return found - - -@profile -def _lcs_iter_simple_alt1(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): - """ - Depth first stack trajectory - """ - all_decomp1 = generate_all_decompositions(full_seq1, open_to_close, open_to_tok) - all_decomp2 = generate_all_decompositions(full_seq2, open_to_close, open_to_tok) - - args0 = (full_seq1, full_seq2) - frame0 = args0 - stack = [frame0] - - _results = {} - # Populate base cases - empty1 = type(ub.peek(all_decomp1.keys()))() - empty2 = type(ub.peek(all_decomp2.keys()))() - best = (empty1, empty2) - base_result = (0, best) - for seq1 in all_decomp1.keys(): - key1 = seq1 - t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] - _results[(seq1, empty2)] = base_result - _results[(head1, empty2)] = base_result - _results[(tail1, empty2)] = base_result - _results[(head_tail1, empty2)] = base_result - - for seq2 in all_decomp2.keys(): - key2 = seq2 - t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] - _results[(empty1, seq2)] = base_result - _results[(empty1, head2)] = base_result - _results[(empty1, tail2)] = base_result - _results[(empty1, head_tail2)] = base_result - - del args0 - del frame0 - del empty1 - del empty2 - del best - del base_result - - while stack: - key = stack.pop() - if key not in _results: - seq1, seq2 = key - - t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] - - t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] - - # Case 2: The current edge in sequence1 is deleted - try: - try_key = (head_tail1, seq2) - cand1 = _results[try_key] - except KeyError: - stack.append(key) - stack.append(try_key) - continue - - # Case 3: The current edge in sequence2 is deleted - try: - try_key = (seq1, head_tail2) - cand2 = _results[try_key] - except KeyError: - stack.append(key) - stack.append(try_key) - continue - - # Case 1: The LCS involves this edge - affinity = node_affinity(t1, t2) - if affinity: - try: - try_key = (head1, head2) - pval_h, new_heads = _results[try_key] - except KeyError: - stack.append(key) - stack.append(try_key) - continue - - try: - try_key = (tail1, tail2) - pval_t, new_tails = _results[try_key] - except KeyError: - stack.append(key) - stack.append(try_key) - continue - - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - res3 = (subseq1, subseq2) - val3 = pval_h + pval_t + affinity - cand3 = (val3, res3) - else: - cand3 = (-1, None) - - # We solved the frame - _results[key] = max(cand1, cand2, cand3) - - val, best = _results[key] - found = (best, val) - return found - - -@profile -def _lcs_iter_simple_alt2(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): - """ - Depth first stack trajectory and replace try except statements with ifs - """ - all_decomp1 = generate_all_decompositions(full_seq1, open_to_close, open_to_tok) - all_decomp2 = generate_all_decompositions(full_seq2, open_to_close, open_to_tok) - - key0 = (full_seq1, full_seq2) - frame0 = key0 - stack = [frame0] - - _results = {} - # Populate base cases - empty1 = type(ub.peek(all_decomp1.keys()))() - empty2 = 
type(ub.peek(all_decomp2.keys()))() - best = (empty1, empty2) - base_result = (0, best) - for seq1 in all_decomp1.keys(): - key1 = seq1 - t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[key1] - _results[(seq1, empty2)] = base_result - _results[(head1, empty2)] = base_result - _results[(tail1, empty2)] = base_result - _results[(head_tail1, empty2)] = base_result - - for seq2 in all_decomp2.keys(): - key2 = seq2 - t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[key2] - _results[(empty1, seq2)] = base_result - _results[(empty1, head2)] = base_result - _results[(empty1, tail2)] = base_result - _results[(empty1, head_tail2)] = base_result - - del frame0 - del empty1 - del empty2 - del best - del base_result - - while stack: - key = stack[-1] - if key not in _results: - seq1, seq2 = key - - t1, a1, b1, head1, tail1, head_tail1 = all_decomp1[seq1] - t2, a2, b2, head2, tail2, head_tail2 = all_decomp2[seq2] - - # Case 2: The current edge in sequence1 is deleted - try_key = (head_tail1, seq2) - if try_key in _results: - cand1 = _results[try_key] - else: - # stack.append(key) - stack.append(try_key) - continue - - # Case 3: The current edge in sequence2 is deleted - try_key = (seq1, head_tail2) - if try_key in _results: - cand2 = _results[try_key] - else: - # stack.append(key) - stack.append(try_key) - continue - - # Case 1: The LCS involves this edge - affinity = node_affinity(t1, t2) - if affinity: - try_key = (head1, head2) - if try_key in _results: - pval_h, new_heads = _results[try_key] - else: - # stack.append(key) - stack.append(try_key) - continue - - try_key = (tail1, tail2) - if try_key in _results: - pval_t, new_tails = _results[try_key] - else: - # stack.append(key) - stack.append(try_key) - continue - - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - res3 = (subseq1, subseq2) - val3 = pval_h + pval_t + affinity - cand3 = (val3, res3) - else: - cand3 = (-1, None) - - # We solved the frame - _results[key] = max(cand1, cand2, cand3) - stack.pop() - - val, best = _results[key0] - found = (best, val) - return found - - -@profile -def _lcs_iter_prehash(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): - """ - Version of the lcs iterative algorithm where we precompute hash values - - This is actually slower than the simple version - """ - def decomp_info(seq, open_to_close): - pop_open, pop_close, head, tail = balanced_decomp_unsafe(seq, open_to_close) - head_tail = head + tail - head_key = hash(head) - tail_key = hash(tail) - head_tail_key = hash(head_tail) - tok = open_to_tok[pop_open[0]] - a = pop_open - b = pop_close - info = (tok, seq, head, tail, head_tail, head_key, tail_key, head_tail_key, a, b) - return info - - def gen_decomp_v2(seq, open_to_close): - _genmemo = {} - def _gen(seq): - if seq: - key = hash(seq) - if key not in _genmemo: - info = decomp_info(seq, open_to_close) - head, tail, head_tail = info[2:5] - _genmemo[key] = info - yield (seq, _genmemo[key]) - yield from _gen(head_tail) - yield from _gen(head) - yield from _gen(tail) - all_decomp = dict(_gen(seq)) - return all_decomp - - all_decomp1 = gen_decomp_v2(full_seq1, open_to_close) - all_decomp2 = gen_decomp_v2(full_seq2, open_to_close) - - key_decomp1 = {} - key_decomp2 = {} - _results = {} - # Populate base cases - empty1 = type(ub.peek(all_decomp1.keys()))() - empty2 = type(ub.peek(all_decomp2.keys()))() - empty1_key = hash(empty1) - empty2_key = hash(empty2) - best = (empty1, empty2) - 
base_result = (0, best) - for seq1, info1 in all_decomp1.items(): - seq1_key = hash(seq1) - head1_key, tail1_key, head_tail1_key = all_decomp1[seq1][5:8] - _results[(seq1_key, empty2_key)] = base_result - _results[(head1_key, empty2_key)] = base_result - _results[(tail1_key, empty2_key)] = base_result - _results[(head_tail1_key, empty2_key)] = base_result - key_decomp1[seq1_key] = info1 - - for seq2, info2 in all_decomp2.items(): - seq2_key = hash(seq2) - head2_key, tail2_key, head_tail2_key = all_decomp2[seq2][5:8] - _results[(empty1_key, seq2_key)] = base_result - _results[(empty1_key, head2_key)] = base_result - _results[(empty1_key, tail2_key)] = base_result - _results[(empty1_key, head_tail2_key)] = base_result - key_decomp2[seq2_key] = info2 - - full_seq1_key = hash(full_seq1) - full_seq2_key = hash(full_seq2) - key0 = (full_seq1_key, full_seq2_key) - frame0 = key0, full_seq1, full_seq2 - stack = [frame0] - missing_frames = [] - while stack: - frame = stack.pop() - key, seq1, seq2 = frame - seq1_key, seq2_key = key - if key not in _results: - missing_frames.clear() - - try: - info1 = key_decomp1[seq1_key] - except KeyError: - info1 = decomp_info(seq1, open_to_close) - key_decomp1[seq1_key] = info1 - tok1, seq1, head1, tail1, head_tail1, head1_key, tail1_key, head_tail1_key, a1, b1 = info1 - - try: - info2 = key_decomp2[seq2_key] - except KeyError: - info2 = decomp_info(seq2, open_to_close) - key_decomp2[seq2_key] = info2 - tok2, seq2, head2, tail2, head_tail2, head2_key, tail2_key, head_tail2_key, a2, b2 = info2 - - affinity = node_affinity(tok1, tok2) - - # Case 2: The current edge in sequence1 is deleted - try: - try_key = (head_tail1_key, seq2_key) - cand1 = _results[try_key] - except KeyError: - miss_frame = try_key, head_tail1, seq2 - missing_frames.append(miss_frame) - - # Case 3: The current edge in sequence2 is deleted - try: - try_key = (seq1_key, head_tail2_key) - cand2 = _results[try_key] - except KeyError: - miss_frame = try_key, seq1, head_tail2 - missing_frames.append(miss_frame) - - # Case 1: The LCS involves this edge - if affinity: - try: - try_key = (head1_key, head2_key) - pval_h, new_heads = _results[try_key] - except KeyError: - miss_frame = try_key, head1, head2 - missing_frames.append(miss_frame) - - try: - try_key = (tail1_key, tail2_key) - pval_t, new_tails = _results[try_key] - except KeyError: - miss_frame = try_key, tail1, tail2 - missing_frames.append(miss_frame) - - if not missing_frames: - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - res3 = (subseq1, subseq2) - val3 = pval_h + pval_t + affinity - cand3 = (val3, res3) - else: - cand3 = (-1, None) - - if missing_frames: - # We did not solve this frame yet - stack.append(frame) - stack.extend(missing_frames[::-1]) - else: - # We solved the frame - _results[key] = max(cand1, cand2, cand3) - - # The stack pop is our solution - (val, best) = _results[key] - found = (best, val) - return found - - -def generate_all_decompositions(seq, open_to_close, open_to_tok=None): - """ - Can doing this a-priori speed up the algorithm? 
- - open_to_close = {0: 1} - sequence = [0, 0, 0, 1, 1, 1, 0, 1] - open_to_close = {'{': '}', '(': ')', '[': ']'} - seq = '({[[]]})[[][]]{{}}' - pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) - - >>> tree = random_ordered_tree(10) - >>> seq, open_to_close, toks = tree_to_balanced_sequence(tree) - >>> all_decomp = generate_all_decompositions(seq, open_to_close) - """ - if open_to_tok is None: - class Dummy: - def __getitem__(self, key): - return key - open_to_tok = Dummy() - _memo = {} - def _gen(seq): - if not seq: - pass - # yield None - elif seq in _memo: - pass - # yield (seq, _memo[seq]) - else: - pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) - head_tail = head + tail - tok = open_to_tok[pop_open[0]] - _memo[seq] = (tok, pop_open, pop_close, head, tail, head_tail) - yield (seq, _memo[seq]) - yield from _gen(head_tail) - yield from _gen(head) - yield from _gen(tail) - all_decomp = dict(_gen(seq)) - return all_decomp - - -@profile -def _lcs_recurse(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): - if not seq1: - return (seq1, seq1), 0 - elif not seq2: - return (seq2, seq2), 0 - else: - # if len(seq2) < len(seq1): - # seq1, seq2 = seq2, seq1 - # key = (seq1, seq2) - key1 = hash(seq1) # using hash(seq) is faster than seq itself - key2 = hash(seq2) - key = hash((key1, key2)) - if key in _memo: - return _memo[key] - - # TODO: we can probably just do a single linear run through the - # sequences to index the sub-sequence locations and then apply an - # offset when we run the decomposed sequence. - if key1 in _seq_memo: - a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] - else: - a1, b1, head1, tail1 = balanced_decomp_unsafe(seq1, open_to_close) - head1_tail1 = head1 + tail1 - _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 - - if key2 in _seq_memo: - a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] - else: - a2, b2, head2, tail2 = balanced_decomp_unsafe(seq2, open_to_close) - head2_tail2 = head2 + tail2 - _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 - - # Case 2: The current edge in sequence1 is deleted - best, val = _lcs_recurse(head1_tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - # Case 3: The current edge in sequence2 is deleted - cand, val_alt = _lcs_recurse(seq1, head2_tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - if val_alt > val: - best = cand - val = val_alt - - # Case 1: The LCS involves this edge - t1 = open_to_tok[a1[0]] - t2 = open_to_tok[a2[0]] - # if node_affinity(a1[0], a2[0]): - affinity = node_affinity(t1, t2) - if affinity: - new_heads, pval_h = _lcs_recurse(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - new_tails, pval_t = _lcs_recurse(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - cand = (subseq1, subseq2) - val_alt = pval_h + pval_t + affinity - if val_alt > val: - best = cand - val = val_alt - - found = (best, val) - _memo[key] = found - return found - - -class UnbalancedException(Exception): - pass - - -def balanced_decomp(sequence, open_to_close): - """ - Note this is not exactly the same as the decomposition in the paper. - That is because we also return the "wrapping" element, and we let the - user do the head + tail concatenation. 
- - Example: - >>> open_to_close = {0: 1} - >>> sequence = [0, 0, 0, 1, 1, 1, 0, 1] - >>> open_to_close = {'{': '}', '(': ')', '[': ']'} - >>> sequence = '({[[]]})[[][]]' - >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) - >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) - """ - gen = generate_balance(sequence, open_to_close) - - bal_curr, tok_curr = next(gen) - pop_open = sequence[0:1] - want_close = open_to_close[tok_curr] - - head_stop = 1 - for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): - if tok_curr is None: - break - elif bal_curr and tok_curr == want_close: - pop_close = sequence[head_stop:head_stop + 1] - break - head = sequence[1:head_stop] - # if __debug__: - # list(gen) # exhaust the generator to check we are balanced - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - - -def tree_to_balanced_sequence(tree, open_to_close=None, toks=None, mode='tuple', strhack=False): - from collections import namedtuple - Token = namedtuple('Token', ['action', 'value']) - # mapping between opening and closing tokens - sources = [n for n in tree.nodes if tree.in_degree[n] == 0] - sequence = [] - - if open_to_close is None: - open_to_close = {} - if toks is None: - toks = {} - - if strhack: - if mode == 'label': - all_labels = {n['label'] for n in list(tree.nodes.values())} - assert all(x == 1 for x in map(len, all_labels)) - - for source in sources: - for u, v, etype in nx.dfs_labeled_edges(tree, source=source): - if etype == 'forward': - # u has been visited by v has not - if v not in toks: - if mode == 'tuple': - # TODO: token encoding scheme where subdirectories - # are matchable via a custom operation. - # open_tok = '<{}>'.format(v) - # close_tok = ''.format(v) - open_tok = Token('open', v) - close_tok = Token('close', v) - elif mode == 'number': - open_tok = len(toks) + 1 - close_tok = -open_tok - elif mode == 'paren': - open_tok = '{}('.format(v) - close_tok = '){}'.format(v) - elif mode == 'chr': - open_tok = str(v) - close_tok = str(v) + u'\u0301' - elif mode == 'label': - open_tok = tree.nodes[v]['label'] - assert strhack - if open_tok == '{': - close_tok = '}' - if open_tok == '[': - close_tok = ']' - if open_tok == '(': - close_tok = ')' - toks[v] = open_tok - open_to_close[open_tok] = close_tok - open_tok = toks[v] - sequence.append(open_tok) - elif etype == 'reverse': - # Both u and v are visited and the edge is in the tree - close_tok = open_to_close[toks[v]] - sequence.append(close_tok) - else: - raise KeyError(etype) - sequence = tuple(sequence) - if strhack: - sequence = ''.join(sequence) - return sequence, open_to_close, toks - - -def seq_to_tree(subseq, open_to_close, toks): - open_to_tok = ub.invert_dict(toks) - subtree = nx.OrderedDiGraph() - stack = [] - for token in subseq: - if token in open_to_close: - node = open_to_tok[token] - if stack: - parent = open_to_tok[stack[-1]] - subtree.add_edge(parent, node) - else: - subtree.add_node(node) - stack.append(token) - else: - if not stack: - raise Exception - prev_open = stack.pop() - want_close = open_to_close[prev_open] - if token != want_close: - raise Exception - return subtree - - -def random_ordered_tree(n, seed=None, pool=None): - import kwarray - rng = kwarray.ensure_rng(seed, 'python') - tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) - otree = nx.OrderedDiGraph() - otree.add_edges_from(tree.edges) - if pool is not None: - for node in otree.nodes: - otree.nodes[node]['label'] = rng.choice(pool) - return otree - - -def 
generate_balance_unsafe(sequence, open_to_close): - """ - Benchmark: - >>> tree = random_ordered_tree(1000) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='tuple') - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='number') - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2) - >>> for timer in ti.reset('time'): - >>> with timer: - >>> list(generate_balance_unsafe(sequence, open_to_close)) - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2) - >>> for timer in ti.reset('time'): - >>> with timer: - >>> list(generate_balance_unsafe_cython(sequence, open_to_close)) - """ - stacklen = 0 - for token in sequence: - if token in open_to_close: - stacklen += 1 - else: - stacklen -= 1 - yield stacklen == 0, token - - -def balanced_decomp_unsafe(sequence, open_to_close): - """ - Example: - >>> open_to_close = {'{': '}', '(': ')', '[': ']'} - >>> sequence = '({[[]]})[[][]]' - >>> print('sequence = {!r}'.format(sequence)) - >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) - >>> print('a1 = {!r}'.format(a1)) - >>> print('tail = {!r}'.format(tail)) - >>> print('head = {!r}'.format(head)) - >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) - >>> print('a2 = {!r}'.format(a2)) - >>> print('tail1 = {!r}'.format(tail1)) - >>> print('tail2 = {!r}'.format(tail2)) - """ - gen = generate_balance_unsafe(sequence, open_to_close) - - bal_curr, tok_curr = next(gen) - pop_open = sequence[0:1] - want_close = open_to_close[tok_curr] - - head_stop = 1 - for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): - if bal_curr and tok_curr == want_close: - pop_close = sequence[head_stop:head_stop + 1] - break - head = sequence[1:head_stop] - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - - -def generate_balance(sequence, open_to_close): - """ - Safe version - - Example: - >>> open_to_close = {0: 1} - >>> sequence = [0, 0, 0, 1, 1, 1] - >>> gen = list(generate_balance(sequence, open_to_close)) - >>> for flag, token in gen: - >>> print('flag={:d}, token={}'.format(flag, token)) - - Example: - >>> tree = random_ordered_tree(1000) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) - >>> gen = list(generate_balance(sequence, open_to_close)) - >>> for flag, token in gen: - >>> print('flag={:d}, token={}'.format(flag, token)) - """ - stack = [] - # Traversing the Expression - for token in sequence: - - if token in open_to_close: - # Push opening elements onto the stack - stack.append(token) - else: - # Check that closing elements - if not stack: - raise UnbalancedException - prev_open = stack.pop() - want_close = open_to_close[prev_open] - - if token != want_close: - raise UnbalancedException - - # If the stack is empty the sequence is currently balanced - currently_balanced = not bool(stack) - yield currently_balanced, token - - if stack: - raise UnbalancedException - - -def _print_forest(graph): - """ - Nice ascii representation of a forest - - Ignore: - graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) - _print_forest(graph) - - graph = CategoryTree.demo('coco').graph - _print_forest(graph) - """ - if len(graph.nodes) == 0: - print('--') - return - assert nx.is_forest(graph) - - def _recurse(node, indent='', islast=False): - if islast: - this_prefix = indent + '└── ' - next_prefix = indent + ' ' - else: - this_prefix = indent + '├── ' - next_prefix = indent + '│   ' - label = graph.nodes[node].get('label', node) - print(this_prefix + 
str(label)) - graph.succ[node] - children = graph.succ[node] - for idx, child in enumerate(children, start=1): - islast_next = (idx == len(children)) - _recurse(child, indent=next_prefix, islast=islast_next) - - sources = [n for n in graph.nodes if graph.in_degree[n] == 0] - for idx, node in enumerate(sources, start=1): - islast_next = (idx == len(sources)) - _recurse(node, indent='', islast=islast_next) - - -__notes_ = """ - - # if 0: - # tuples = [(i + 1, i + 2, i + 3,) for i in range(4)] - # import timerit - - # ti = timerit.Timerit(100, bestof=10, verbose=2) - # import itertools as it - # for timer in ti.reset('time'): - # with timer: - # tuple(it.chain.from_iterable(tuples)) - # for timer in ti.reset('time'): - # with timer: - # res = tuples[0] - # for a in tuples[1:]: - # res = res + a - -""" diff --git a/netharn/initializers/bseq2.py b/netharn/initializers/bseq2.py deleted file mode 100644 index 1b26842d0c1f2cb1f6cabb7e290a861613bea51a..0000000000000000000000000000000000000000 --- a/netharn/initializers/bseq2.py +++ /dev/null @@ -1,612 +0,0 @@ -import operator -import ubelt as ub -import networkx as nx - -try: - import xdev - profile = xdev.profile -except Exception: - profile = ub.identity - - -def longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok=None, node_affinity='auto', impl='iter'): - """ - CommandLine: - xdoctest -m /home/joncrall/code/netharn/netharn/initializers/balanced_sequence.py longest_common_balanced_sequence:0 --profile && cat profile_output.txt - - Example: - >>> from netharn.initializers.balanced_sequence import * # NOQA - >>> tree1 = random_ordered_tree(5, seed=10, pool='[{(') - >>> tree2 = random_ordered_tree(5, seed=3, pool='[{(') - - >>> import kwarray - >>> rng = kwarray.ensure_rng(None, 'python') - >>> tree1 = random_ordered_tree(100, seed=rng, pool='[{(') - >>> tree2 = random_ordered_tree(200, seed=rng, pool='[{(') - >>> if len(tree1.nodes) < 20: - >>> _print_forest(tree1) - >>> _print_forest(tree2) - >>> seq1, open_to_close, toks = tree_to_balanced_sequence(tree1, mode='label', strhack=1) - >>> seq2, open_to_close, toks = tree_to_balanced_sequence(tree2, open_to_close, toks, mode='label', strhack=1) - >>> full_seq1 = seq1 - >>> full_seq2 = seq2 - >>> print('seq1 = {!r}'.format(seq1)) - >>> print('seq2 = {!r}'.format(seq2)) - >>> open_to_tok = ub.invert_dict(toks) - >>> with ub.Timer('recursive'): - >>> best2, val2 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='recurse') - >>> print('val2, best2 = {}, {!r}'.format(val2, best2)) - >>> with ub.Timer('iterative'): - >>> best1, val1 = longest_common_balanced_sequence(seq1, seq2, open_to_close, open_to_tok, impl='iter') - >>> print('val1, best1 = {}, {!r}'.format(val1, best1)) - """ - if node_affinity == 'auto' or node_affinity == 'eq': - node_affinity = operator.eq - if node_affinity is None: - def _matchany(a, b): - return True - node_affinity = _matchany - _memo = {} - _seq_memo = {} - if open_to_tok is None: - class Dummy: - def __getitem__(self, key): - return key - open_to_tok = Dummy() - full_seq1 = seq1 - full_seq2 = seq2 - if impl == 'recurse': - best, value = _lcs_recurse(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - elif impl == 'iter': - best, value = _lcs_iter(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok) - else: - raise KeyError(impl) - return best, value - - -@profile -def _lcs_iter(full_seq1, full_seq2, open_to_close, node_affinity, open_to_tok): - def decomp_info(seq, open_to_close): 
- pop_open, pop_close, head, tail = balanced_decomp_unsafe(seq, open_to_close) - head_tail = head + tail - head_key = hash(head) - tail_key = hash(tail) - head_tail_key = hash(head_tail) - tok = open_to_tok[pop_open[0]] - a = pop_open - b = pop_close - info = (tok, seq, head, tail, head_tail, head_key, tail_key, head_tail_key, a, b) - return info - - def gen_decomp_v2(seq, open_to_close): - _genmemo = {} - def _gen(seq): - if seq: - key = hash(seq) - if key not in _genmemo: - info = decomp_info(seq, open_to_close) - head, tail, head_tail = info[2:5] - _genmemo[key] = info - yield (seq, _genmemo[key]) - yield from _gen(head_tail) - yield from _gen(head) - yield from _gen(tail) - all_decomp = dict(_gen(seq)) - return all_decomp - - all_decomp1 = gen_decomp_v2(full_seq1, open_to_close) - all_decomp2 = gen_decomp_v2(full_seq2, open_to_close) - - key_decomp1 = {} - key_decomp2 = {} - _results = {} - # Populate base cases - empty1 = type(ub.peek(all_decomp1.keys()))() - empty2 = type(ub.peek(all_decomp2.keys()))() - empty1_key = hash(empty1) - empty2_key = hash(empty2) - best = (empty1, empty2) - base_result = (0, best) - for seq1, info1 in all_decomp1.items(): - seq1_key = hash(seq1) - head1_key, tail1_key, head_tail1_key = all_decomp1[seq1][5:8] - _results[(seq1_key, empty2_key)] = base_result - _results[(head1_key, empty2_key)] = base_result - _results[(tail1_key, empty2_key)] = base_result - _results[(head_tail1_key, empty2_key)] = base_result - key_decomp1[seq1_key] = info1 - - for seq2, info2 in all_decomp2.items(): - seq2_key = hash(seq2) - head2_key, tail2_key, head_tail2_key = all_decomp2[seq2][5:8] - _results[(empty1_key, seq2_key)] = base_result - _results[(empty1_key, head2_key)] = base_result - _results[(empty1_key, tail2_key)] = base_result - _results[(empty1_key, head_tail2_key)] = base_result - key_decomp2[seq2_key] = info2 - - full_seq1_key = hash(full_seq1) - full_seq2_key = hash(full_seq2) - key0 = (full_seq1_key, full_seq2_key) - frame0 = key0, full_seq1, full_seq2 - stack = [frame0] - missing_frames = [] - num_misses = 0 - while stack: - frame = stack.pop() - key, seq1, seq2 = frame - seq1_key, seq2_key = key - if key not in _results: - missing_frames.clear() - - try: - info1 = key_decomp1[seq1_key] - except KeyError: - info1 = decomp_info(seq1, open_to_close) - key_decomp1[seq1_key] = info1 - tok1, seq1, head1, tail1, head_tail1, head1_key, tail1_key, head_tail1_key, a1, b1 = info1 - - try: - info2 = key_decomp2[seq2_key] - except KeyError: - info2 = decomp_info(seq2, open_to_close) - key_decomp2[seq2_key] = info2 - tok2, seq2, head2, tail2, head_tail2, head2_key, tail2_key, head_tail2_key, a2, b2 = info2 - - affinity = node_affinity(tok1, tok2) - - # Case 2: The current edge in sequence1 is deleted - try: - try_key = (head_tail1_key, seq2_key) - cand1 = _results[try_key] - except KeyError: - miss_frame = try_key, head_tail1, seq2 - missing_frames.append(miss_frame) - - # Case 3: The current edge in sequence2 is deleted - try: - try_key = (seq1_key, head_tail2_key) - cand2 = _results[try_key] - except KeyError: - miss_frame = try_key, seq1, head_tail2 - missing_frames.append(miss_frame) - - # Case 1: The LCS involves this edge - if affinity: - try: - try_key = (head1_key, head2_key) - pval_h, new_heads = _results[try_key] - except KeyError: - miss_frame = try_key, head1, head2 - missing_frames.append(miss_frame) - - try: - try_key = (tail1_key, tail2_key) - pval_t, new_tails = _results[try_key] - except KeyError: - miss_frame = try_key, tail1, tail2 - 
missing_frames.append(miss_frame) - - if not missing_frames: - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - res3 = (subseq1, subseq2) - val3 = pval_h + pval_t + affinity - cand3 = (val3, res3) - else: - cand3 = (-1, None) - - if missing_frames: - num_misses += 1 - # We did not solve this frame yet - stack.append(frame) - stack.extend(missing_frames[::-1]) - else: - # We solved the frame - _results[key] = max(cand1, cand2, cand3) - - print('num_misses = {!r}'.format(num_misses)) - - # The stack pop is our solution - (val, best) = _results[key] - found = (best, val) - return found - - -def generate_all_decompositions(seq, open_to_close): - """ - Can doing this a-priori speed up the algorithm? - - open_to_close = {0: 1} - sequence = [0, 0, 0, 1, 1, 1, 0, 1] - open_to_close = {'{': '}', '(': ')', '[': ']'} - seq = '({[[]]})[[][]]{{}}' - pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) - - >>> tree = random_ordered_tree(1000) - >>> seq, open_to_close, toks = tree_to_balanced_sequence(tree) - >>> all_decomp = _generate_all_decompositions(seq, open_to_close) - """ - _memo = {} - def _gen(seq): - if not seq: - pass - # yield None - elif seq in _memo: - pass - # yield (seq, _memo[seq]) - else: - pop_open, pop_close, head, tail = balanced_decomp(seq, open_to_close) - head_tail = head + tail - _memo[seq] = (pop_open, pop_close, head, tail, head_tail) - yield (seq, _memo[seq]) - yield from _gen(head_tail) - yield from _gen(head) - yield from _gen(tail) - all_decomp = dict(_gen(seq)) - return all_decomp - - -@profile -def _lcs_recurse(seq1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo): - if not seq1: - return (seq1, seq1), 0 - elif not seq2: - return (seq2, seq2), 0 - else: - # if len(seq2) < len(seq1): - # seq1, seq2 = seq2, seq1 - # key = (seq1, seq2) - key1 = hash(seq1) # using hash(seq) is faster than seq itself - key2 = hash(seq2) - key = hash((key1, key2)) - if key in _memo: - return _memo[key] - - # TODO: we can probably just do a single linear run through the - # sequences to index the sub-sequence locations and then apply an - # offset when we run the decomposed sequence. 
- if key1 in _seq_memo: - a1, b1, head1, tail1, head1_tail1 = _seq_memo[key1] - else: - a1, b1, head1, tail1 = balanced_decomp_unsafe(seq1, open_to_close) - head1_tail1 = head1 + tail1 - _seq_memo[key1] = a1, b1, head1, tail1, head1_tail1 - - if key2 in _seq_memo: - a2, b2, head2, tail2, head2_tail2 = _seq_memo[key2] - else: - a2, b2, head2, tail2 = balanced_decomp_unsafe(seq2, open_to_close) - head2_tail2 = head2 + tail2 - _seq_memo[key2] = a2, b2, head2, tail2, head2_tail2 - - # Case 2: The current edge in sequence1 is deleted - best, val = _lcs_recurse(head1_tail1, seq2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - # Case 3: The current edge in sequence2 is deleted - cand, val_alt = _lcs_recurse(seq1, head2_tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - if val_alt > val: - best = cand - val = val_alt - - # Case 1: The LCS involves this edge - t1 = open_to_tok[a1[0]] - t2 = open_to_tok[a2[0]] - # if node_affinity(a1[0], a2[0]): - affinity = node_affinity(t1, t2) - if affinity: - new_heads, pval_h = _lcs_recurse(head1, head2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - new_tails, pval_t = _lcs_recurse(tail1, tail2, open_to_close, node_affinity, open_to_tok, _memo, _seq_memo) - - new_head1, new_head2 = new_heads - new_tail1, new_tail2 = new_tails - - subseq1 = a1 + new_head1 + b1 + new_tail1 - subseq2 = a2 + new_head2 + b2 + new_tail2 - - cand = (subseq1, subseq2) - val_alt = pval_h + pval_t + affinity - if val_alt > val: - best = cand - val = val_alt - - found = (best, val) - _memo[key] = found - return found - - -class UnbalancedException(Exception): - pass - - -def balanced_decomp(sequence, open_to_close): - """ - Note this is not exactly the same as the decomposition in the paper. - That is because we also return the "wrapping" element, and we let the - user do the head + tail concatenation. 
- - Example: - >>> open_to_close = {0: 1} - >>> sequence = [0, 0, 0, 1, 1, 1, 0, 1] - >>> open_to_close = {'{': '}', '(': ')', '[': ']'} - >>> sequence = '({[[]]})[[][]]' - >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) - >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) - """ - gen = generate_balance(sequence, open_to_close) - - bal_curr, tok_curr = next(gen) - pop_open = sequence[0:1] - want_close = open_to_close[tok_curr] - - head_stop = 1 - for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): - if tok_curr is None: - break - elif bal_curr and tok_curr == want_close: - pop_close = sequence[head_stop:head_stop + 1] - break - head = sequence[1:head_stop] - # if __debug__: - # list(gen) # exhaust the generator to check we are balanced - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - - -def tree_to_balanced_sequence(tree, open_to_close=None, toks=None, mode='tuple', strhack=False): - from collections import namedtuple - Token = namedtuple('Token', ['action', 'value']) - # mapping between opening and closing tokens - sources = [n for n in tree.nodes if tree.in_degree[n] == 0] - sequence = [] - - if open_to_close is None: - open_to_close = {} - if toks is None: - toks = {} - - if strhack: - if mode == 'label': - all_labels = {n['label'] for n in list(tree.nodes.values())} - assert all(x == 1 for x in map(len, all_labels)) - - for source in sources: - for u, v, etype in nx.dfs_labeled_edges(tree, source=source): - if etype == 'forward': - # u has been visited by v has not - if v not in toks: - if mode == 'tuple': - # TODO: token encoding scheme where subdirectories - # are matchable via a custom operation. - # open_tok = '<{}>'.format(v) - # close_tok = ''.format(v) - open_tok = Token('open', v) - close_tok = Token('close', v) - elif mode == 'number': - open_tok = len(toks) + 1 - close_tok = -open_tok - elif mode == 'paren': - open_tok = '{}('.format(v) - close_tok = '){}'.format(v) - elif mode == 'chr': - open_tok = str(v) - close_tok = str(v) + u'\u0301' - elif mode == 'label': - open_tok = tree.nodes[v]['label'] - assert strhack - if open_tok == '{': - close_tok = '}' - if open_tok == '[': - close_tok = ']' - if open_tok == '(': - close_tok = ')' - toks[v] = open_tok - open_to_close[open_tok] = close_tok - open_tok = toks[v] - sequence.append(open_tok) - elif etype == 'reverse': - # Both u and v are visited and the edge is in the tree - close_tok = open_to_close[toks[v]] - sequence.append(close_tok) - else: - raise KeyError(etype) - sequence = tuple(sequence) - if strhack: - sequence = ''.join(sequence) - return sequence, open_to_close, toks - - -def seq_to_tree(subseq, open_to_close, toks): - open_to_tok = ub.invert_dict(toks) - subtree = nx.OrderedDiGraph() - stack = [] - for token in subseq: - if token in open_to_close: - node = open_to_tok[token] - if stack: - parent = open_to_tok[stack[-1]] - subtree.add_edge(parent, node) - else: - subtree.add_node(node) - stack.append(token) - else: - if not stack: - raise Exception - prev_open = stack.pop() - want_close = open_to_close[prev_open] - if token != want_close: - raise Exception - return subtree - - -def random_ordered_tree(n, seed=None, pool=None): - import kwarray - rng = kwarray.ensure_rng(seed, 'python') - tree = nx.dfs_tree(nx.random_tree(n, seed=seed)) - otree = nx.OrderedDiGraph() - otree.add_edges_from(tree.edges) - if pool is not None: - for node in otree.nodes: - otree.nodes[node]['label'] = rng.choice(pool) - return otree - - -def 
generate_balance_unsafe(sequence, open_to_close): - """ - Benchmark: - >>> tree = random_ordered_tree(1000) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='tuple') - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree, mode='number') - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2) - >>> for timer in ti.reset('time'): - >>> with timer: - >>> list(generate_balance_unsafe(sequence, open_to_close)) - >>> import timerit - >>> ti = timerit.Timerit(100, bestof=10, verbose=2) - >>> for timer in ti.reset('time'): - >>> with timer: - >>> list(generate_balance_unsafe_cython(sequence, open_to_close)) - """ - stacklen = 0 - for token in sequence: - if token in open_to_close: - stacklen += 1 - else: - stacklen -= 1 - yield stacklen == 0, token - - -def balanced_decomp_unsafe(sequence, open_to_close): - """ - Example: - >>> open_to_close = {'{': '}', '(': ')', '[': ']'} - >>> sequence = '({[[]]})[[][]]' - >>> print('sequence = {!r}'.format(sequence)) - >>> a1, b1, head, tail = balanced_decomp(sequence, open_to_close) - >>> print('a1 = {!r}'.format(a1)) - >>> print('tail = {!r}'.format(tail)) - >>> print('head = {!r}'.format(head)) - >>> a2, b2, tail1, tail2 = balanced_decomp(tail, open_to_close) - >>> print('a2 = {!r}'.format(a2)) - >>> print('tail1 = {!r}'.format(tail1)) - >>> print('tail2 = {!r}'.format(tail2)) - """ - gen = generate_balance_unsafe(sequence, open_to_close) - - bal_curr, tok_curr = next(gen) - pop_open = sequence[0:1] - want_close = open_to_close[tok_curr] - - head_stop = 1 - for head_stop, (bal_curr, tok_curr) in enumerate(gen, start=1): - if bal_curr and tok_curr == want_close: - pop_close = sequence[head_stop:head_stop + 1] - break - head = sequence[1:head_stop] - tail = sequence[head_stop + 1:] - return pop_open, pop_close, head, tail - - -def generate_balance(sequence, open_to_close): - """ - Safe version - - Example: - >>> open_to_close = {0: 1} - >>> sequence = [0, 0, 0, 1, 1, 1] - >>> gen = list(generate_balance(sequence, open_to_close)) - >>> for flag, token in gen: - >>> print('flag={:d}, token={}'.format(flag, token)) - - Example: - >>> tree = random_ordered_tree(1000) - >>> sequence, open_to_close, toks = tree_to_balanced_sequence(tree) - >>> gen = list(generate_balance(sequence, open_to_close)) - >>> for flag, token in gen: - >>> print('flag={:d}, token={}'.format(flag, token)) - """ - stack = [] - # Traversing the Expression - for token in sequence: - - if token in open_to_close: - # Push opening elements onto the stack - stack.append(token) - else: - # Check that closing elements - if not stack: - raise UnbalancedException - prev_open = stack.pop() - want_close = open_to_close[prev_open] - - if token != want_close: - raise UnbalancedException - - # If the stack is empty the sequence is currently balanced - currently_balanced = not bool(stack) - yield currently_balanced, token - - if stack: - raise UnbalancedException - - -def _print_forest(graph): - """ - Nice ascii representation of a forest - - Ignore: - graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) - _print_forest(graph) - - graph = CategoryTree.demo('coco').graph - _print_forest(graph) - """ - if len(graph.nodes) == 0: - print('--') - return - assert nx.is_forest(graph) - - def _recurse(node, indent='', islast=False): - if islast: - this_prefix = indent + '└── ' - next_prefix = indent + ' ' - else: - this_prefix = indent + '├── ' - next_prefix = indent + '│   ' - label = graph.nodes[node].get('label', node) - print(this_prefix + 
str(label)) - graph.succ[node] - children = graph.succ[node] - for idx, child in enumerate(children, start=1): - islast_next = (idx == len(children)) - _recurse(child, indent=next_prefix, islast=islast_next) - - sources = [n for n in graph.nodes if graph.in_degree[n] == 0] - for idx, node in enumerate(sources, start=1): - islast_next = (idx == len(sources)) - _recurse(node, indent='', islast=islast_next) - - -__notes_ = """ - - # if 0: - # tuples = [(i + 1, i + 2, i + 3,) for i in range(4)] - # import timerit - - # ti = timerit.Timerit(100, bestof=10, verbose=2) - # import itertools as it - # for timer in ti.reset('time'): - # with timer: - # tuple(it.chain.from_iterable(tuples)) - # for timer in ti.reset('time'): - # with timer: - # res = tuples[0] - # for a in tuples[1:]: - # res = res + a - -""" diff --git a/netharn/initializers/functional.py b/netharn/initializers/functional.py index 1c488c509a425b91566922bd03cb5f216b1c1b6b..c5599dd67dd4ca9f7f7dffbf4d84b906258a00d7 100644 --- a/netharn/initializers/functional.py +++ b/netharn/initializers/functional.py @@ -300,6 +300,37 @@ def load_partial_state(model, model_state_dict, leftover=None, """ other_keys = set(model_state_dict) self_keys = set(self_state) + + if 0: + # Automatic way to reduce nodes in the trees? + # If node b always follows node a, can we contract it? + nodes1 = [n for p in other_keys for n in p.split('.')] + nodes2 = [n for p in self_keys for n in p.split('.')] + tups1 = list(tup for key in other_keys for tup in ub.iter_window(key.split('.'), 2)) + tups2 = list(tup for key in self_keys for tup in ub.iter_window(key.split('.'), 2)) + x = ub.ddict(list) + for a, b in tups1: + x[a].append(b) + for a, b in tups2: + x[a].append(b) + + nodehist = ub.dict_hist(nodes1 + nodes2) + + for k, v in x.items(): + print('----') + print(k) + print(nodehist[k]) + follow_hist = ub.dict_hist(v) + print(follow_hist) + total = sum(follow_hist.values()) + if ub.allsame(follow_hist.values()) and total == nodehist[k]: + print('CONTRACT') + + # pair_freq = ub.dict_hist(ub.flatten([tups1, tups2])) + from netharn.initializers._nx_ext.tree_embedding import forest_str + from netharn.initializers._nx_ext.path_embedding import paths_to_otree + print(forest_str(paths_to_otree(other_keys, '.'))) + common_keys = other_keys.intersection(self_keys) if not common_keys: if association == 'strict': @@ -343,7 +374,43 @@ def load_partial_state(model, model_state_dict, leftover=None, # I believe this is the correct way to solve the problem paths1 = sorted(other_keys) paths2 = sorted(self_state) - subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2) + + if 1: + # hack to filter to reduce tree size in embedding problem + def shrink_paths(paths): + new_paths = [] + for p in paths: + p = p.replace('.0', ':0') + p = p.replace('.1', ':1') + p = p.replace('.2', ':2') + p = p.replace('.3', ':3') + p = p.replace('.4', ':4') + p = p.replace('.5', ':5') + p = p.replace('.6', ':6') + p = p.replace('.7', ':7') + p = p.replace('.8', ':8') + p = p.replace('.9', ':9') + p = p.replace('.weight', ':weight') + p = p.replace('.bias', ':bias') + p = p.replace('.num_batches_tracked', ':num_batches_tracked') + p = p.replace('.running_mean', ':running_mean') + p = p.replace('.running_var', ':running_var') + p = p.replace('.conv1', ':conv1') + p = p.replace('.conv2', ':conv2') + p = p.replace('.conv3', ':conv3') + p = p.replace('.bn1', ':bn1') + p = p.replace('.bn2', ':bn2') + p = p.replace('.bn3', ':bn3') + new_paths.append(p) + return new_paths + + paths1_ = 
shrink_paths(paths1) + paths2_ = shrink_paths(paths2) + + # Reducing the depth saves a lot of time + subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1_, paths2_, sep='.') + subpaths1 = [p.replace(':', '.') for p in subpaths1] + subpaths2 = [p.replace(':', '.') for p in subpaths2] mapping = ub.dzip(subpaths1, subpaths2) if verbose > 1: print('mapping = {}'.format(ub.repr2(mapping, nl=1))) @@ -603,7 +670,7 @@ def maximum_common_ordered_subpaths(paths1, paths2, sep='.'): Example: >>> import torchvision >>> resnet50 = torchvision.models.resnet50() - >>> paths1 = sorted(resnet50.state_dict().keys())[0:100] + >>> paths1 = sorted(resnet50.state_dict().keys()) >>> paths2 = ['prefix.' + k for k in paths1] >>> paths2.append('extra_key') >>> subpaths1, subpaths2 = maximum_common_ordered_subpaths(paths1, paths2) @@ -710,17 +777,20 @@ def maximum_common_ordered_subpaths(paths1, paths2, sep='.'): tree1 = paths_to_tree(paths1) tree2 = paths_to_tree(paths2) - # _print_forest(tree1) - # _print_forest(tree2) + # from netharn.initializers._nx_ext.tree_embedding import forest_str + print(len(tree1.nodes)) + print(len(tree2.nodes)) + # print(forest_str(tree1)) + # print(forest_str(tree2)) # if 0: # DiGM = isomorphism.DiGraphMatcher(tree1, tree2) # DiGM.is_isomorphic() # list(DiGM.subgraph_isomorphisms_iter()) - from netharn.initializers import _nx_extensions - subtree1, subtree2 = _nx_extensions.maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity=node_affinity) - # subtree1, subtree2 = _nx_extensions.maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity=node_affinity) + from netharn.initializers import _nx_ext + subtree1, subtree2 = _nx_ext.maximum_common_ordered_tree_embedding(tree1, tree2, node_affinity=node_affinity) + # subtree1, subtree2 = _nx_ext.maximum_common_ordered_subtree_isomorphism(tree1, tree2, node_affinity=node_affinity) subpaths1 = [sep.join(node) for node in subtree1.nodes if subtree1.out_degree[node] == 0] subpaths2 = [sep.join(node) for node in subtree2.nodes if subtree2.out_degree[node] == 0] diff --git a/netharn/initializers/pretrained.py b/netharn/initializers/pretrained.py index 751b8de277ce8d91f35cebcc4e491f16159020ec..5634e98fc1a4259d4739ecd78f9b3d70f6e25027 100644 --- a/netharn/initializers/pretrained.py +++ b/netharn/initializers/pretrained.py @@ -41,7 +41,8 @@ class Pretrained(api.Initializer, ub.NiceRepr): classification layer if class indexes are not aligned. association (str): controls how we search for the association between - the two model states. Can be strict, module-hack, prefix-hack, or embedding. + the two model states. Can be strict, module-hack, prefix-hack, or + embedding. info (dict, optional): specify explicit history info diff --git a/netharn/layers/norm.py b/netharn/layers/norm.py index 8aaa267bc1b2c09a9c78ef02392751b40f821408..e175bdade327ce2412bf755f4face8e323aa1084 100644 --- a/netharn/layers/norm.py +++ b/netharn/layers/norm.py @@ -87,6 +87,26 @@ class InputNorm(common.Module): >>> # Specifying either the mean or the std is ok. 
>>> partial1 = InputNorm(mean=50)(inputs)
         >>> partial2 = InputNorm(std=29)(inputs)
+
+        import torch
+
+        model = torch.nn.Sequential(*[
+            InputNorm(mean=10, std=0.2),
+            torch.nn.Conv2d(3, 3, 3),
+        ])
+        inputs = torch.rand(2, 3, 5, 7) * 100
+        optim = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+        for i in range(100):
+            optim.zero_grad()
+            x = model(inputs).sum()
+            x.backward()
+            optim.step()
+
+        mean = model[0].mean
+        std = model[0].std
+        print('mean = {!r}'.format(mean))
+        print('std = {!r}'.format(std))
     """
 
     def __init__(self, mean=None, std=None):
diff --git a/netharn/layers/perceptron.py b/netharn/layers/perceptron.py
index 6c6f32702f663b171908e6756a901c73c4bd626e..0d6ba03cf017f1a2f946aece4eb5018d6320c47a 100644
--- a/netharn/layers/perceptron.py
+++ b/netharn/layers/perceptron.py
@@ -20,7 +20,8 @@ class MultiLayerPerceptronNd(common.Module):
             input and output channels)
         out_channels (int):
         dropout (float, default=0): amount of dropout to use
-        norm (str, default='batch'): type of normalization layer (e.g. batch or group)
+        norm (str, default='batch'): type of normalization layer
+            (e.g. batch or group), set to None for no normalization.
         noli (str, default='relu'): type of nonlinearity
 
         residual (bool, default=False):
             if true includes a resitual skip connection between inputs and
diff --git a/netharn/mixins.py b/netharn/mixins.py
index f083a0243cc7d303f859081c2affa2aeb6f2301b..58d0052ce939e2bce7e843d99059fd5de2d4e6c9 100644
--- a/netharn/mixins.py
+++ b/netharn/mixins.py
@@ -25,9 +25,9 @@ def _dump_monitor_tensorboard(harn, mode='epoch', special_groupers=['loss'],
         xdoctest -m netharn.mixins _dump_monitor_tensorboard --profile
 
     Example:
-        >>> from netharn.export.deployer import _demodata_toy_harn
+        >>> import netharn as nh
         >>> from netharn.mixins import _dump_monitor_tensorboard
-        >>> harn = _demodata_toy_harn()
+        >>> harn = nh.FitHarn.demo()
         >>> harn.run()
         >>> try:
         >>>     _dump_monitor_tensorboard(harn)
@@ -41,7 +41,7 @@ def _dump_monitor_tensorboard(harn, mode='epoch', special_groupers=['loss'],
     import six
     from six.moves import cPickle as pickle
 
-    harn.debug('Plotting tensorboard data. serial={}, mode={}'.format(serial, mode))
+    # harn.debug('Plotting tensorboard data. 
serial={}, mode={}'.format(serial, mode)) train_dpath = harn.train_dpath diff --git a/netharn/models/toynet.py b/netharn/models/toynet.py index 4b596214d7375f1cbb3114158c267df886561fbc..7a1ac7d905c238332652104d86e7a9b1a1223f0c 100644 --- a/netharn/models/toynet.py +++ b/netharn/models/toynet.py @@ -17,6 +17,8 @@ class ToyNet1d(layers.Module): """ def __init__(self, input_channels=2, num_classes=2): super(ToyNet1d, self).__init__() + self.input_channels = input_channels + self.num_classes = num_classes self.layers = torch.nn.Sequential(*[ torch.nn.Linear(input_channels, 8), @@ -55,6 +57,8 @@ class ToyNet2d(layers.Module): """ def __init__(self, input_channels=1, num_classes=2): super(ToyNet2d, self).__init__() + self.input_channels = input_channels + self.num_classes = num_classes self.layers = torch.nn.Sequential(*[ torch.nn.Conv2d(input_channels, 8, kernel_size=3, padding=1, bias=False), diff --git a/netharn/prefit/lr_tests.py b/netharn/prefit/lr_tests.py index da91694011a39f547053ef50541e0470a2c9ab16..64ff5561f6f01acd8a20af494702d4c987fdf5a7 100644 --- a/netharn/prefit/lr_tests.py +++ b/netharn/prefit/lr_tests.py @@ -37,8 +37,8 @@ def lr_range_test(harn, init_value=1e-8, final_value=10., beta=0.98, Example: >>> from netharn.prefit.lr_tests import * - >>> from netharn.export.deployer import _demodata_toy_harn - >>> harn = _demodata_toy_harn().initialize() + >>> import netharn as nh + >>> harn = nh.FitHarn.demo().initialize() >>> result = lr_range_test(harn) >>> print('result = {!r}'.format(result)) >>> # xdoctest: +REQUIRES(--show) @@ -247,8 +247,8 @@ def lr_range_scan(harn, low=1e-6, high=10.0, num=8, niter_train=1, Example: >>> from netharn.prefit.lr_tests import * - >>> from netharn.export.deployer import _demodata_toy_harn - >>> harn = _demodata_toy_harn().initialize() + >>> import netharn as nh + >>> harn = nh.FitHarn.demo().initialize() >>> result = lr_range_scan(harn) >>> print('result = {!r}'.format(result)) >>> # xdoctest: +REQUIRES(--show) diff --git a/netharn/schedulers/core.py b/netharn/schedulers/core.py index 2bb2e02a2036205a436faf0c9e0a6be3b2a267b4..2967e852a3d0ac3ef412fa05762c6ceb81e8bcaa 100644 --- a/netharn/schedulers/core.py +++ b/netharn/schedulers/core.py @@ -1,6 +1,56 @@ import torch.optim.lr_scheduler from collections import defaultdict +""" + +# Notes on torch schedulers + +import torch +from torch.optim import lr_scheduler +from torch import optim + + +parameters = list(torch.nn.Conv1d(1, 1, 1).parameters()) + +base_lr = 1e-3 +optimizer = optim.SGD(parameters, lr=base_lr) + + +schedulers = {} +scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20) +schedulers[scheduler.__class__.__name__] = scheduler +scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=base_lr, total_steps=100) +schedulers[scheduler.__class__.__name__] = scheduler +scheduler = lr_scheduler.StepLR(optimizer, step_size=30) +schedulers[scheduler.__class__.__name__] = scheduler +scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9) +schedulers[scheduler.__class__.__name__] = scheduler + +key = scheduler.__class__.__name__ + + +xdata = list(range(100)) +ydata = ub.ddict(list) + +for key, scheduler in schedulers.items(): + + # Reset optimizer LR + for g in optimizer.param_groups: + g['lr'] = base_lr + + for x in xdata: + lr = scheduler.get_last_lr()[0] + scheduler.step() + ydata[key].append(lr) + +import kwplot +kwplot.autompl() + +kwplot.multi_plot(xdata=xdata, ydata=ydata) + + +""" + class CommonMixin(object): diff --git a/netharn/util/collect_env.py b/netharn/util/collect_env.py new 
file mode 100644 index 0000000000000000000000000000000000000000..f0bdee1097ba12310c1634b422e864e6d3e29a46 --- /dev/null +++ b/netharn/util/collect_env.py @@ -0,0 +1,440 @@ +""" +Adapted from +https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py +""" +# This script outputs relevant system environment info +# Run it with `python collect_env.py`. +import locale +import re +import subprocess +import sys +import os +from collections import namedtuple + +import ubelt as ub +try: + from xdev import profile +except Exception: + profile = ub.profile + +try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError): + TORCH_AVAILABLE = False + +# System Environment Information +SystemEnv = namedtuple('SystemEnv', [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'python_version', + 'is_cuda_available', + 'cuda_runtime_version', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', +]) + + +def run(command): + """Returns (return-code, stdout, stderr)""" + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=True) + raw_output, raw_err = p.communicate() + rc = p.returncode + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + err = raw_err.decode(enc) + return rc, output.strip(), err.strip() + + +def run_and_read_all(run_lambda, command): + """Runs command using run_lambda; reads and returns entire output if rc is 0""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Runs command using run_lambda, returns the first regex match if it exists""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + + +@profile +def get_conda_packages(run_lambda): + # If we are not in cond do nothing + if os.path.exists(os.path.join(sys.prefix, 'conda-meta')): + if get_platform() == 'win32': + system_root = os.environ.get('SystemRoot', 'C:\\Windows') + findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + grep_cmd = r'{} /R "torch numpy cudatoolkit soumith mkl magma"'.format(findstr_cmd) + else: + grep_cmd = r'grep "torch\|numpy\|cudatoolkit\|soumith\|mkl\|magma"' + conda = os.environ.get('CONDA_EXE', 'conda') + out = run_and_read_all(run_lambda, conda + ' list | ' + grep_cmd) + if out is not None: + # Comment starting at beginning of line + comment_regex = re.compile(r'^#.*\n') + return re.sub(comment_regex, '', out) + + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + + +def get_clang_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) 
') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin' or torch.version.hip is not None: + if TORCH_AVAILABLE and torch.cuda.is_available(): + return torch.cuda.get_device_name(None) + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'V(.*)$') + + +def get_cudnn_version(run_lambda): + """This will return a list of libcudnn.so; it's hard to tell which one is being used""" + if get_platform() == 'win32': + system_root = os.environ.get('SystemRoot', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. + cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + lib = os.environ.get('CUDNN_LIBRARY') + if lib is not None and os.path.isfile(lib): + return os.path.realpath(lib) + return None + files_set = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = list(sorted(files_set)) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + smi = '"C:\\Program Files\\NVIDIA Corporation\\NVSMI\\%s"' % smi + return smi + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + + +def get_windows_version(run_lambda): + system_root = os.environ.get('SystemRoot', 'C:\\Windows') + wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') + findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + from platform import machine + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + 
+ if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'Mac OSX {} ({})'.format(version, machine()) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + return '{} ({})'.format(platform, machine()) + + # Unknown platform + return platform + + +@profile +def get_pip_packages(run_lambda): + """Returns `pip list` output. Note: will also find conda-installed pytorch + and numpy packages.""" + + # MODIFICATION: just run with whatever python we are using + pip = sys.executable + ' -m pip' + + if get_platform() == 'win32': + system_root = os.environ.get('SystemRoot', 'C:\\Windows') + findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + grep_cmd = r'{} /R "numpy torch"'.format(findstr_cmd) + else: + grep_cmd = r'grep "torch\|numpy"' + out = run_and_read_all(run_lambda, pip + ' list --format=freeze | ' + grep_cmd) + return 'pip', out + + +@profile +def get_env_info(): + run_lambda = run + + if TORCH_AVAILABLE: + version_str = torch.__version__ + debug_mode_str = str(torch.version.debug) + cuda_available_str = str(torch.cuda.is_available()) + cuda_version_str = torch.version.cuda + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' + + if torch.version.hip is None: # cuda version + gpu_info = dict( + is_cuda_available=cuda_available_str, + cuda_compiled_version=cuda_version_str, + cuda_runtime_version=get_running_cuda_version(run_lambda), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version='N/A', + hip_runtime_version='N/A', + miopen_runtime_version='N/A', + ) + else: # HIP version + cfg = torch._C._show_config().split('\n') + hip_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'HIP Runtime' in s][0] + miopen_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'MIOpen' in s][0] + gpu_info = dict( + is_cuda_available=cuda_available_str, + cuda_compiled_version='N/A', + hip_compiled_version=torch.version.hip, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, + cuda_runtime_version='N/A', + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version='N/A', + ) + + pip_version, pip_list_output = get_pip_packages(run_lambda) + conda_packages = get_conda_packages(run_lambda) + gcc_version = get_gcc_version(run_lambda) + clang_version = get_clang_version(run_lambda) + cmake_version = get_cmake_version(run_lambda) + os_version = get_os(run_lambda) + + return SystemEnv( + torch_version=version_str, + is_debug_build=debug_mode_str, + python_version='{}.{} ({}-bit runtime)'.format(sys.version_info[0], sys.version_info[1], sys.maxsize.bit_length() + 1), + pip_version=pip_version, + pip_packages=pip_list_output, + conda_packages=conda_packages, + os=os_version, + gcc_version=gcc_version, + clang_version=clang_version, + cmake_version=cmake_version, + **gpu_info + ) + +env_info_fmt = """ +PyTorch version: {torch_version} +Is debug build: {is_debug_build} +CUDA used to build PyTorch: {cuda_compiled_version} +ROCM used to build PyTorch: {hip_compiled_version} + +OS: {os} +GCC version: {gcc_version} +Clang version: {clang_version} 
+CMake version: {cmake_version} + +Python version: {python_version} +Is CUDA available: {is_cuda_available} +CUDA runtime version: {cuda_runtime_version} +GPU models and configuration: {nvidia_gpu_models} +Nvidia driver version: {nvidia_driver_version} +cuDNN version: {cudnn_version} +HIP runtime version: {hip_runtime_version} +MIOpen runtime version: {miopen_runtime_version} + +Versions of relevant libraries: +{pip_packages} +{conda_packages} +""".strip() + + +def pretty_str(envinfo): + def replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, tag='[prepend]'): + lines = text.split('\n') + updated_lines = [tag + line for line in lines] + return '\n'.join(updated_lines) + + def replace_if_empty(text, replacement='No relevant packages'): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + if envinfo.cuda_compiled_version is None: + mutable_dict['cuda_compiled_version'] = 'None' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], + '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], + '[conda] ') + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + +if __name__ == '__main__': + """ + CommandLine: + python ~/code/netharn/netharn/util/collect_env.py --profile + """ + main() diff --git a/netharn/util/util_json.py b/netharn/util/util_json.py index e4fc85c83e30f7e9980122dc4867bb65d9e29905..2347c2d5906ece0bc9a51c8e9871e2eadb679ce4 100644 --- a/netharn/util/util_json.py +++ b/netharn/util/util_json.py @@ -6,6 +6,7 @@ import six import 
torch import numpy as np import ubelt as ub +from collections.abc import Generator from collections import OrderedDict @@ -139,6 +140,18 @@ def ensure_json_serializable(dict_, normalize_containers=False, verbose=0): if True, normalizes dict containers to be standard python structures. + Example: + >>> data = ub.ddict(lambda: int) + >>> data['foo'] = ub.ddict(lambda: int) + >>> data['bar'] = np.array([1, 2, 3]) + >>> data['foo']['a'] = 1 + >>> data['foo']['b'] = (1, np.array([1, 2, 3]), {3: np.int(3), 4: np.float16(1.0)}) + >>> dict_ = data + >>> print(ub.repr2(data, nl=-1)) + >>> result = ensure_json_serializable(data, normalize_containers=True) + >>> print(ub.repr2(result, nl=-1)) + >>> assert type(result) is dict + Example: >>> data = ub.ddict(lambda: int) >>> data['foo'] = ub.ddict(lambda: int) @@ -161,100 +174,180 @@ def ensure_json_serializable(dict_, normalize_containers=False, verbose=0): c = dict(c) return c - # inplace convert any ndarrays to lists - def _walk_json(data, prefix=[]): - items = None - if isinstance(data, list): - items = enumerate(data) - elif isinstance(data, tuple): - items = enumerate(data) - elif isinstance(data, dict): - items = data.items() - else: - raise TypeError(type(data)) + walker = IndexableWalker(dict_) + for prefix, value in walker: + if isinstance(value, tuple): + new_value = list(value) + walker[prefix] = new_value + elif isinstance(value, np.ndarray): + new_value = value.tolist() + walker[prefix] = new_value + elif isinstance(value, torch.Tensor): + new_value = value.data.cpu().numpy().tolist() + walker[prefix] = new_value + elif isinstance(value, (np.integer)): + new_value = int(value) + walker[prefix] = new_value + elif isinstance(value, (np.floating)): + new_value = float(value) + walker[prefix] = new_value + elif isinstance(value, (np.complex)): + new_value = complex(value) + walker[prefix] = new_value + elif hasattr(value, '__json__'): + new_value = value.__json__() + walker[prefix] = new_value + elif normalize_containers: + if isinstance(value, dict): + new_value = _norm_container(value) + walker[prefix] = new_value + + if normalize_containers: + # normalize the outer layer + dict_ = _norm_container(dict_) + return dict_ + + +class IndexableWalker(Generator): + """ + Traverses through a nested tree-liked indexable structure. + + Generates a path and value to each node in the structure. The path is a + list of indexes which if applied in order will reach the value. + + The ``__setitem__`` method can be used to modify a nested value based on the + path returned by the generator. + + When generating values, you can use "send" to prevent traversal of a + particular branch. 
+ + Example: + >>> # Create nested data + >>> import numpy as np + >>> data = ub.ddict(lambda: int) + >>> data['foo'] = ub.ddict(lambda: int) + >>> data['bar'] = np.array([1, 2, 3]) + >>> data['foo']['a'] = 1 + >>> data['foo']['b'] = np.array([1, 2, 3]) + >>> data['foo']['c'] = [1, 2, 3] + >>> data['baz'] = 3 + >>> print('data = {}'.format(ub.repr2(data, nl=True))) + >>> # We can walk through every node in the nested tree + >>> walker = IndexableWalker(data) + >>> for path, value in walker: + >>> print('walk path = {}'.format(ub.repr2(path, nl=0))) + >>> if path[-1] == 'c': + >>> # Use send to prevent traversing this branch + >>> got = walker.send(False) + >>> # We can modify the value based on the returned path + >>> walker[path] = 'changed the value of c' + >>> print('data = {}'.format(ub.repr2(data, nl=True))) + >>> assert data['foo']['c'] == 'changed the value of c' + """ + + def __init__(self, data, dict_cls=(dict,), list_cls=(list, tuple)): + self.data = data + self.dict_cls = dict_cls + self.list_cls = list_cls + self.indexable_cls = self.dict_cls + self.list_cls - root = prefix - level = {} - for key, value in items: - level[key] = value + self._walk_gen = None - # yield a dict so the user can choose to not walk down a path - yield root, level + def __iter__(self): + """ + Iterates through the indexable ``self.data`` + + Can send a False flag to prevent a branch from being traversed + + Yields: + Tuple[List, Any] : + path (List): list of index operations to arrive at the value + value (object): the value at the path + """ + return self + + def __next__(self): + """ returns next item from this generator """ + if self._walk_gen is None: + self._walk_gen = self._walk(self.data, prefix=[]) + return next(self._walk_gen) + + def send(self, arg): + """ + send(arg) -> send 'arg' into generator, + return next yielded value or raise StopIteration. + """ + # Note: this will error if called before __next__ + self._walk_gen.send(arg) - for key, value in level.items(): - if isinstance(value, (dict, list, tuple)): - path = prefix + [key] - for _ in _walk_json(value, prefix=path): - yield _ + def throw(self, type=None, value=None, traceback=None): + """ + throw(typ[,val[,tb]]) -> raise exception in generator, + return next yielded value or raise StopIteration. 
+ """ + raise StopIteration - def _convert(dict_, root, key, new_value): - d = dict_ - for k in root: + def __setitem__(self, path, value): + """ + Set nested value by path + + Args: + path (List): list of indexes into the nested structure + value (object): new value + """ + d = self.data + *prefix, key = path + for k in prefix: d = d[k] - d[key] = new_value + d[key] = value - def _flatmap(func, data): - if isinstance(data, list): - return [_flatmap(func, item) for item in data] - else: - return func(data) - - to_convert = [] - for root, level in ub.ProgIter(_walk_json(dict_), desc='walk json', - verbose=verbose): - for key, value in level.items(): - if isinstance(value, tuple): - # Convert tuples on the fly so they become mutable - new_value = list(value) - _convert(dict_, root, key, new_value) - elif isinstance(value, np.ndarray): - new_value = value.tolist() - if 0: - if len(value.shape) == 1: - if value.dtype.kind in {'i', 'u'}: - new_value = list(map(int, new_value)) - elif value.dtype.kind in {'f'}: - new_value = list(map(float, new_value)) - elif value.dtype.kind in {'c'}: - new_value = list(map(complex, new_value)) - else: - pass - else: - if value.dtype.kind in {'i', 'u'}: - new_value = _flatmap(int, new_value) - elif value.dtype.kind in {'f'}: - new_value = _flatmap(float, new_value) - elif value.dtype.kind in {'c'}: - new_value = _flatmap(complex, new_value) - else: - pass - # raise TypeError(value.dtype) - to_convert.append((root, key, new_value)) - elif isinstance(value, torch.Tensor): - new_value = value.data.cpu().numpy().tolist() - to_convert.append((root, key, new_value)) - elif isinstance(value, (np.int16, np.int32, np.int64, - np.uint16, np.uint32, np.uint64)): - new_value = int(value) - to_convert.append((root, key, new_value)) - elif isinstance(value, (np.float32, np.float64)): - new_value = float(value) - to_convert.append((root, key, new_value)) - elif isinstance(value, (np.complex64, np.complex128)): - new_value = complex(value) - to_convert.append((root, key, new_value)) - elif hasattr(value, '__json__'): - new_value = value.__json__() - to_convert.append((root, key, new_value)) - elif normalize_containers: - if isinstance(value, dict): - new_value = _norm_container(value) - to_convert.append((root, key, new_value)) - - for root, key, new_value in to_convert: - _convert(dict_, root, key, new_value) + def __delitem__(self, path): + """ + Remove nested value by path - if normalize_containers: - # normalize the outer layer - dict_ = _norm_container(dict_) - return dict_ + Note: + It can be dangerous to use this while iterating (because we may try + to descend into a deleted location) or on leaf items that are + list-like (because the indexes of all subsequent items will be + modified). + + Args: + path (List): list of indexes into the nested structure. + The item at the last index will be removed. 
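# A hypothetical illustration (not part of the class) of the path semantics
# described above, using the IndexableWalker defined in this diff:
data = {'a': {'b': [1, 2, 3]}}
walker = IndexableWalker(data)
walker[['a', 'b', 0]] = 99    # walk the prefix ['a', 'b'], then assign index 0
del walker[['a', 'b', 2]]     # walk the prefix ['a', 'b'], then delete index 2
assert data == {'a': {'b': [99, 2]}}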
+ """ + d = self.data + *prefix, key = path + for k in prefix: + d = d[k] + del d[key] + + def _walk(self, data, prefix=[]): + """ + Defines the underlying generator used by IndexableWalker + """ + stack = [(data, prefix)] + while stack: + _data, _prefix = stack.pop() + # Create an items iterable of depending on the indexable data type + if isinstance(_data, self.list_cls): + items = enumerate(_data) + elif isinstance(_data, self.dict_cls): + items = _data.items() + else: + raise TypeError(type(_data)) + + for key, value in items: + # Yield the full path to this position and its value + path = _prefix + [key] + message = yield path, value + # If the value at this path is also indexable, then continue + # the traversal, unless the False message was explicitly sent + # by the caller. + if message is False: + # Because the `send` method will return the next value, + # we yield a dummy value so we don't clobber the next + # item in the traversal. + yield None + else: + if isinstance(value, self.indexable_cls): + stack.append((value, path)) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index b947a61320911240c5c7914a1c7bf0ad171a8e37..33256f33bd1fcda63c157809cb707968852212ce 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -29,3 +29,4 @@ kwimage >= 0.4.0 kwplot >= 0.4.0 qualname>=0.1.0;python_version < '3.0' +torch_liberator >= 0.0.4 diff --git a/super_setup.py b/super_setup.py index aa6b163dc61b75fc9a41bf2deb8624ea1cc0bcb6..4d8aa470e39d5f7df649086ef684f40e8e7ae474 100755 --- a/super_setup.py +++ b/super_setup.py @@ -403,10 +403,20 @@ class Repo(ub.NiceRepr): if exists(repo.dpath): raise ValueError('cannot clone into non-empty directory') args = '--recursive' + # NOTE: if the remote branch does not exist this will fail if repo.branch is not None: args += ' -b {}'.format(repo.branch) - command = 'git clone {args} {url} {dpath}'.format(args=args, url=repo.url, dpath=repo.dpath) - repo._cmd(command, cwd=repo.code_dpath) + try: + command = 'git clone {args} {url} {dpath}'.format( + args=args, url=repo.url, dpath=repo.dpath) + repo._cmd(command, cwd=repo.code_dpath) + except Exception as ex: + text = repr(ex) + if 'Remote branch' in text and 'not found' in text: + print('ERROR: It looks like the remote branch you asked for doesnt exist') + print('ERROR: Caused by: ex = {}'.format(text)) + raise Exception('Cannot find branch {} for repo {}'.format(repo.branch, repo)) + raise def _assert_clean(repo): if repo.pygit.is_dirty(): @@ -708,65 +718,68 @@ def determine_code_dpath(): return code_dpath -def make_netharn_registry(): +DEVEL_REPOS = [ + # The util libs + { + 'name': 'kwarray', 'branch': 'dev/0.5.10', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwarray.git'}, + }, + { + 'name': 'kwimage', 'branch': 'dev/0.6.7', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, + }, + { + 'name': 'kwannot', 'branch': 'dev/0.1.0', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwannot.git'}, + }, + { + 'name': 'kwcoco', 'branch': 'dev/0.1.7', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwcoco.git'}, + }, + { + 'name': 'kwplot', 'branch': 'dev/0.4.8', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, + }, + + # Pytorch deployer / exporter + { + 'name': 'liberator', 'branch': 'dev/0.0.2', 'remote': 'public', + 'remotes': {'public': 
'git@gitlab.kitware.com:python/liberator.git'}, + }, + { + 'name': 'torch_liberator', 'branch': 'dev/0.0.5', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/torch_liberator.git'}, + }, + + # For example data and CLI + { + 'name': 'scriptconfig', 'branch': 'dev/0.5.8', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:utils/scriptconfig.git'}, + }, + { + 'name': 'ndsampler', 'branch': 'dev/0.5.12', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'}, + }, + + # netharn - training harness + { + 'name': 'netharn', 'branch': 'dev/0.5.10', 'remote': 'public', + 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/netharn.git'}, + }, +] + + +def make_registry(devel_repos): code_dpath = determine_code_dpath() CommonRepo = functools.partial(Repo, code_dpath=code_dpath) - - devel_repos = [ - # The util libs - { - 'name': 'kwarray', 'branch': 'dev/0.5.10', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwarray.git'}, - }, - { - 'name': 'kwimage', 'branch': 'dev/0.6.6', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, - }, - { - 'name': 'kwcoco', 'branch': 'dev/0.1.6', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwcoco.git'}, - }, - { - 'name': 'kwplot', 'branch': 'dev/0.4.8', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, - }, - - # Pytorch deployer / exporter - { - 'name': 'liberator', 'branch': 'dev/0.0.2', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:python/liberator.git'}, - }, - { - 'name': 'torch_liberator', 'branch': 'dev/0.0.5', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/torch_liberator.git'}, - }, - - # For example data and CLI - { - 'name': 'scriptconfig', 'branch': 'dev/0.5.8', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:utils/scriptconfig.git'}, - }, - { - 'name': 'ndsampler', 'branch': 'dev/0.5.12', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'}, - }, - - # netharn - training harness - { - 'name': 'netharn', 'branch': 'dev/0.5.9', 'remote': 'public', - 'remotes': {'public': 'git@gitlab.kitware.com:computer-vision/netharn.git'}, - }, - ] - repos = [CommonRepo(**kw) for kw in devel_repos] - registery = RepoRegistry(repos) return registery def main(): - - registery = make_netharn_registry() + devel_repos = DEVEL_REPOS + registery = make_registry(devel_repos) only = ub.argval('--only', default=None) if only is not None: diff --git a/tests/test_run_sequence.py b/tests/test_run_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea8172fb10366922502f5bb90c6b9a0a2e3445f --- /dev/null +++ b/tests/test_run_sequence.py @@ -0,0 +1,108 @@ +""" +Tests the order in which things happen in "run" +""" +# import torch.nn.functional as F +import numpy as np +import ubelt as ub +import netharn as nh +import torch + + +class Failpoint(Exception): + pass + + +class MyHarn(nh.FitHarn): + + def run_batch(harn, raw_batch): + if harn.epoch == harn.failpoint and harn.batch_index >= 4: + raise Failpoint + + x = torch.Tensor([[1, 2]]) + f = torch.nn.Linear(2, 1) + y = f(x) + loss = y.sum() + output = y + + # harn._all_iters[harn.current_tag].append(harn.iter_index) + # batch = harn.xpu.move(raw_batch) + # output = harn.model(batch['im']) + # log_probs = 
F.log_softmax(output, dim=1) + # loss_parts = { + # 'nll_loss': F.nll_loss(log_probs, batch['label']), + # } + return output, loss + + +def test_run_sequence(): + """ + main test function + """ + datasets = { + 'train': nh.data.ToyData2d(size=3, border=1, n=7, rng=0), + 'vali': nh.data.ToyData2d(size=3, border=1, n=3, rng=0), + } + model = nh.models.ToyNet2d() + + hyper = { + # --- data first + 'datasets' : datasets, + 'nice' : 'test_run_sequence', + 'workdir' : ub.ensure_app_cache_dir('netharn/test/test_run_sequence'), + 'loaders' : {'batch_size': 1}, + 'xpu' : nh.XPU.coerce('cpu'), + # --- algorithm second + 'model' : model, + 'optimizer' : nh.api.Optimizer.coerce({'optim': 'sgd'}), + 'initializer' : nh.api.Initializer.coerce({'init': 'noop'}), + 'scheduler' : nh.api.Scheduler.coerce({'scheduler': 'step-3-7'}), + 'dynamics' : nh.api.Dynamics.coerce({'batch_step': 1, 'warmup_iters': 6}), + 'monitor' : (nh.Monitor, {'max_epoch': 4}), + } + harn1 = MyHarn(hyper=hyper) + harn1.preferences['verbose'] = 1 + harn1.preferences['use_tensorboard'] = False + harn1.preferences['eager_dump_tensorboard'] = False + + harn1.intervals['log_iter_train'] = 1 + harn1.intervals['log_iter_vali'] = 1 + harn1.intervals['cleanup'] = 5 + # Delete previous data + harn1.initialize(reset='delete') + + # Cause the harness to fail + try: + harn1.failpoint = 0 + harn1.run() + except Failpoint: + pass + print('\nFAILPOINT REACHED\n') + + # Restarting the harness should begin at the same point + harn2 = MyHarn(hyper=hyper) + harn2.preferences.update(harn1.preferences) + harn2.intervals.update(harn1.intervals) + harn2.failpoint = None + harn2.run() + + if 0: + idxs1 = harn1._all_iters['train'] + idxs2 = harn2._all_iters['train'] + diff1 = np.diff(idxs1) + diff2 = np.diff(idxs2) + print('idxs1 = {!r}'.format(idxs1)) + print('idxs2 = {!r}'.format(idxs2)) + print('diff1 = {!r}'.format(diff1)) + print('diff2 = {!r}'.format(diff2)) + assert np.all(diff1 == 1) + assert np.all(diff2 == 1) + assert idxs1[0] == 0 + assert idxs1[-1] == (idxs2[0] - 1) + + +if __name__ == '__main__': + """ + CommandLine: + python ~/code/netharn/tests/test_run_sequence.py + """ + test_run_sequence()
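
Note on the IndexableWalker hunk above: the added methods support path-based assignment (__setitem__), path-based deletion (__delitem__), and a stack-based traversal generator (_walk) whose descent can be pruned by sending False. The sketch below illustrates that intended usage. It assumes the class is importable from netharn.util.util_json, that its constructor simply wraps the nested container (stored as self.data), and that iteration and send delegate to the _walk generator; the sample data and values are hypothetical.

    # Minimal usage sketch for the IndexableWalker API shown in the hunk above.
    from netharn.util.util_json import IndexableWalker

    data = {'a': [1, 2, {'b': 3}], 'c': {'d': [4, 5]}}
    walker = IndexableWalker(data)

    for path, value in walker:
        if value == 3:
            # __setitem__ walks the path prefix and assigns at the final key
            walker[path] = 30
        if path == ['c', 'd']:
            # Sending False prevents descent into this subtree; the generator
            # yields a dummy value so the next item is not clobbered.
            walker.send(False)

    # Edits go through to the wrapped container
    assert data == {'a': [1, 2, {'b': 30}], 'c': {'d': [4, 5]}}

    # __delitem__ removes the value at a path. Per its docstring, avoid using
    # it mid-iteration or on list items whose later indexes would shift.
    del walker[['a', 2]]
    assert data == {'a': [1, 2], 'c': {'d': [4, 5]}}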
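
Note on the super_setup.py hunk above: lifting the repo table into the module-level DEVEL_REPOS constant and passing it into make_registry means a registry can also be built from a filtered list of repos. A hypothetical sketch of that reuse follows; it assumes super_setup.py is importable (e.g. when running from the repository root) and the chosen subset is arbitrary.

    # Hypothetical: build a registry for only a couple of the listed repos.
    import super_setup

    subset = [row for row in super_setup.DEVEL_REPOS
              if row['name'] in {'netharn', 'torch_liberator'}]
    registry = super_setup.make_registry(subset)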