From bc0e8d59058942d09e8031d33cf38b8faebf799f Mon Sep 17 00:00:00 2001 From: joncrall Date: Thu, 14 Nov 2019 17:30:09 -0500 Subject: [PATCH 01/24] support more activations --- CHANGELOG.md | 8 ++++++ netharn/layers/rectify.py | 48 ++++++++++++++++++++-------------- netharn/receptive_field_for.py | 2 +- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7da3c50..c2efd13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,14 @@ This changelog follows the specifications detailed in: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release. + +## Version 0.5.2 + +### Added + +* Rectify nonlinearity now supports more torch activations + + ## Version 0.5.1 ### Changed diff --git a/netharn/layers/rectify.py b/netharn/layers/rectify.py index 560a9a0..1170015 100644 --- a/netharn/layers/rectify.py +++ b/netharn/layers/rectify.py @@ -26,14 +26,7 @@ def rectify_nonlinearity(key=ub.NoParam, dim=2): key = 'relu' if isinstance(key, six.string_types): - if key == 'relu': - key = {'type': 'relu'} - elif key == 'relu6': - key = {'type': 'relu6'} - elif key == 'leaky_relu': - key = {'type': 'leaky_relu', 'negative_slope': 1e-2} - else: - raise KeyError(key) + key = {'type': key} elif isinstance(key, dict): key = key.copy() else: @@ -47,6 +40,12 @@ def rectify_nonlinearity(key=ub.NoParam, dim=2): cls = torch.nn.LeakyReLU elif noli_type == 'relu': cls = torch.nn.ReLU + elif noli_type == 'elu': + cls = torch.nn.ELU + elif noli_type == 'celu': + cls = torch.nn.CELU + elif noli_type == 'selu': + cls = torch.nn.SELU elif noli_type == 'relu6': cls = torch.nn.ReLU6 else: @@ -54,10 +53,15 @@ def rectify_nonlinearity(key=ub.NoParam, dim=2): return cls(**kw) -def rectify_normalizer(in_channels, key=ub.NoParam, dim=2): +def rectify_normalizer(in_channels, key=ub.NoParam, dim=2, **kwargs): """ Allows dictionary based specification of a normalizing layer + Args: + in_channels (int): number of input channels + dim (int): dimensionality + **kwargs: extra args + Example: >>> rectify_normalizer(8) BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) @@ -82,16 +86,7 @@ def rectify_normalizer(in_channels, key=ub.NoParam, dim=2): key = 'batch' if isinstance(key, six.string_types): - if key == 'batch': - key = {'type': 'batch'} - elif key == 'syncbatch': - key = {'type': 'syncbatch'} - elif key == 'group': - key = {'type': 'group', 'num_groups': ('gcd', min(in_channels, 32))} - elif key == 'batch+group': - key = {'type': 'batch+group'} - else: - raise KeyError(key) + key = {'type': key} elif isinstance(key, dict): key = key.copy() else: @@ -116,6 +111,9 @@ def rectify_normalizer(in_channels, key=ub.NoParam, dim=2): cls = torch.nn.SyncBatchNorm elif norm_type == 'group': in_channels_key = 'num_channels' + if key.get('num_groups') is None: + key['num_groups'] = ('gcd', min(in_channels, 32)) + if isinstance(key['num_groups'], tuple): if key['num_groups'][0] == 'gcd': key['num_groups'] = gcd( @@ -134,7 +132,17 @@ def rectify_normalizer(in_channels, key=ub.NoParam, dim=2): raise KeyError('unknown type: {}'.format(key)) assert in_channels_key not in key key[in_channels_key] = in_channels - return cls(**key) + + try: + import copy + kw = copy.copy(key) + kw.update(kwargs) + return cls(**kw) + except Exception: + # Ignore kwargs + import warnings + warnings.warn('kwargs ignored in rectify normalizer') + return cls(**key) def 
rectify_conv(dim=2):
diff --git a/netharn/receptive_field_for.py b/netharn/receptive_field_for.py
index 1875700..d01af50 100644
--- a/netharn/receptive_field_for.py
+++ b/netharn/receptive_field_for.py
@@ -531,7 +531,7 @@ class _TorchMixin(object):
         return ReceptiveFieldFor._unchanged(module, input_field)

     @staticmethod
-    @compute_type(nn.ReLU6, nn.PReLU, nn.LeakyReLU)
+    @compute_type(nn.ReLU6, nn.PReLU, nn.LeakyReLU, nn.ELU, nn.CELU, nn.SELU)
     def _unchanged_activation(module, input_field=None):
         return ReceptiveFieldFor._unchanged(module, input_field)
--
GitLab
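The net effect of the simplified key handling in this patch is that any string key becomes ``{'type': key}`` and dispatches to the matching torch class, so the new activations need no per-key plumbing. A minimal usage sketch (assuming the helpers are imported from ``netharn.layers.rectify``; the exact public re-exports are not shown in the patch):

    from netharn.layers.rectify import rectify_nonlinearity, rectify_normalizer

    # Strings pass straight through, so newly supported activations just work
    noli = rectify_nonlinearity('celu')    # equivalent to {'type': 'celu'}
    noli = rectify_nonlinearity({'type': 'leaky_relu', 'negative_slope': 1e-2})

    # Group norm now defaults num_groups to ('gcd', min(in_channels, 32))
    norm = rectify_normalizer(in_channels=16, key={'type': 'group'})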
From de892866e243b9f8b2d8075febb053058641bf43 Mon Sep 17 00:00:00 2001
From: joncrall
Date: Fri, 15 Nov 2019 14:59:51 -0500
Subject: [PATCH 02/24] wip

---
 requirements/problematic.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements/problematic.txt b/requirements/problematic.txt
index e2d00f5..cc82faf 100644
--- a/requirements/problematic.txt
+++ b/requirements/problematic.txt
@@ -1 +1,2 @@
 # These are optional requirements that are problematic when installing via pip
+pycocotools
--
GitLab

From 9812c794935541416c4d5604144c7a499f04b028 Mon Sep 17 00:00:00 2001
From: Jon Crall
Date: Tue, 19 Nov 2019 08:53:01 -0500
Subject: [PATCH 03/24] Update supersetup

---
 super_setup.py | 93 +++++++++++++++++++++++++++++++++++---------------
 1 file changed, 65 insertions(+), 28 deletions(-)

diff --git a/super_setup.py b/super_setup.py
index 41b9995..54b3cae 100755
--- a/super_setup.py
+++ b/super_setup.py
@@ -254,9 +254,9 @@ class Repo(ub.NiceRepr):
         else:
             repo.debug(ub.color_text('Ensuring {}'.format(repo), 'blue'))

-        if dry:
-            if not exists(repo.dpath):
-                repo.debug('NEED TO CLONE {}'.format(repo))
+        if not exists(repo.dpath):
+            repo.debug('NEED TO CLONE {}'.format(repo))
+            if dry:
                 return

         repo.ensure_clone()
@@ -266,13 +266,16 @@ class Repo(ub.NiceRepr):
         # Ensure all registered remotes exist
         for remote_name, remote_url in repo.remotes.items():
             try:
-                repo.pygit.remotes[remote_name]
+                remote = repo.pygit.remotes[remote_name]
+                have_urls = list(remote.urls)
+                if remote_url not in have_urls:
+                    print('WARNING: REMOTE NAME EXISTS BUT URL IS NOT {}. '
+                          'INSTEAD GOT: {}'.format(remote_url, have_urls))
             except (IndexError):
                 try:
-                    if dry:
-                        print('NEED TO ADD REMOTE {}->{} FOR {}'.format(
-                            remote_name, remote_url, repo))
-                    else:
+                    print('NEED TO ADD REMOTE {}->{} FOR {}'.format(
+                        remote_name, remote_url, repo))
+                    if not dry:
                         repo._cmd('git remote add {} {}'.format(remote_name, remote_url))
                 except Exception:
                     if remote_name == repo.remote:
@@ -281,20 +284,33 @@ class Repo(ub.NiceRepr):

         # Ensure we are on the right branch
         if repo.branch != repo.pygit.active_branch.name:
-            if dry:
-                repo.debug('NEED TO SET BRANCH TO {} for {}'.format(repo.branch, repo))
-            else:
+            repo.debug('NEED TO SET BRANCH TO {} for {}'.format(repo.branch, repo))
+            if not dry:
                 try:
+                    remote = repo.pygit.remotes[repo.remote]
+                    if not remote.exists():
+                        raise IndexError
+                except IndexError:
+                    repo.debug('WARNING: remote={} does not exist'.format(remote))
+                else:
+                    if remote.exists():
+                        remote_branchnames = [ref.remote_head for ref in remote.refs]
+                        if repo.branch not in remote_branchnames:
+                            repo.info('Branch name not found in local remote. Attempting to fetch')
+                            repo._cmd('git fetch {}'.format(remote.name))
+                            # remote.fetch()
+
                     repo._cmd('git checkout {}'.format(repo.branch))
-                except Exception:
-                    repo._cmd('git fetch --all')
-                    repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch))
+                # try:
+                #     repo._cmd('git checkout {}'.format(repo.branch))
+                # except Exception:
+                #     repo._cmd('git fetch --all')
+                #     repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch))

         tracking_branch = repo.pygit.active_branch.tracking_branch()
         if tracking_branch is None or tracking_branch.remote_name != repo.remote:
-            if dry:
-                repo.debug('NEED TO SET UPSTREAM FOR FOR {}'.format(repo))
-            else:
+            repo.debug('NEED TO SET UPSTREAM FOR {}'.format(repo))
+            if not dry:
                 try:
                     remote = repo.pygit.remotes[repo.remote]
                     if not remote.exists():
@@ -303,15 +319,25 @@ class Repo(ub.NiceRepr):
                     repo.debug('WARNING: remote={} does not exist'.format(remote))
                 else:
                     if remote.exists():
-                        try:
-                            repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format(
-                                remote=repo.remote, branch=repo.branch
-                            ))
-                        except Exception:
-                            repo._cmd('git fetch --all')
-                            repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format(
-                                remote=repo.remote, branch=repo.branch
-                            ))
+                        remote_branchnames = [ref.remote_head for ref in remote.refs]
+                        if repo.branch not in remote_branchnames:
+                            repo.info('Branch name not found in local remote. Attempting to fetch')
+                            remote.fetch()
+
+                        repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format(
+                            remote=repo.remote, branch=repo.branch
+                        ))
+
+                        # try:
+                        #     repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format(
+                        #         remote=repo.remote, branch=repo.branch
+                        #     ))
+                        # except Exception:
+                        #     # remote.fetch()
+                        #     repo._cmd('git fetch --all')
+                        #     repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format(
+                        #         remote=repo.remote, branch=repo.branch
+                        #     ))

         # Print some status
         repo.debug(' * branch = {} -> {}'.format(
@@ -470,8 +496,8 @@ def make_netharn_registry():

         # For example data and CLI
         CommonRepo(
-            name='scriptconfig', branch='dev/0.5.1', remote='computer-vision',
-            remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/scriptconfig.git'},
+            name='scriptconfig', branch='dev/0.5.1', remote='utils',
+            remotes={'utils': 'git@gitlab.kitware.com:utils/scriptconfig.git'},
         ),
         CommonRepo(
             name='ndsampler', branch='dev/0.5.0', remote='computer-vision',
@@ -492,6 +518,11 @@ def main():
     import click
     registery = make_netharn_registry()

+    only = ub.argval('--only', default=None)
+    if only is not None:
+        only = only.split(',')
+        registery.repos = [repo for repo in registery.repos if repo.name in only]
+
     num_workers = int(ub.argval('--workers', default=8))
     if ub.argflag('--serial'):
         num_workers = 0
@@ -513,6 +544,9 @@ def main():
     @cli_group.add_command
     @click.command('ensure', context_settings=default_context_settings)
     def ensure():
+        """
+        Ensure is the live run of "check".
+        """
         registery.apply('ensure', num_workers=num_workers)

     @cli_group.add_command
@@ -523,6 +557,9 @@ def main():
     @cli_group.add_command
     @click.command('check', context_settings=default_context_settings)
     def check():
+        """
+        Check is just a dry run of "ensure".
+ """ registery.apply('check', num_workers=num_workers) @cli_group.add_command -- GitLab From 68550b713a3cdaf2beffe915304baa09c3ab7c56 Mon Sep 17 00:00:00 2001 From: Jon Crall Date: Tue, 19 Nov 2019 08:54:20 -0500 Subject: [PATCH 04/24] Fix branch --- super_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/super_setup.py b/super_setup.py index 54b3cae..154e359 100755 --- a/super_setup.py +++ b/super_setup.py @@ -506,7 +506,7 @@ def make_netharn_registry(): # netharn - training harness CommonRepo( - name='netharn', branch='dev/0.5.1', remote='computer-vision', + name='netharn', branch='dev/0.5.2', remote='computer-vision', remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/netharn.git'}, ), ] -- GitLab From 8df1a8231774704f6b5b3030061ba2aeb89210c2 Mon Sep 17 00:00:00 2001 From: Jon Crall Date: Tue, 19 Nov 2019 08:55:03 -0500 Subject: [PATCH 05/24] wip --- super_setup.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/super_setup.py b/super_setup.py index 154e359..dbec519 100755 --- a/super_setup.py +++ b/super_setup.py @@ -481,33 +481,33 @@ def make_netharn_registry(): # The util libs CommonRepo( - name='kwarray', branch='dev/0.5.2', remote='computer-vision', - remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/kwarray.git'}, + name='kwarray', branch='dev/0.5.2', remote='public', + remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwarray.git'}, ), CommonRepo( - name='kwimage', branch='dev/0.5.2', remote='computer-vision', - remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, + name='kwimage', branch='dev/0.5.2', remote='public', + remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, ), CommonRepo( - name='kwplot', branch='dev/0.4.0', remote='computer-vision', - remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, + name='kwplot', branch='dev/0.4.0', remote='public', + remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, ), # For example data and CLI CommonRepo( - name='scriptconfig', branch='dev/0.5.1', remote='utils', - remotes={'utils': 'git@gitlab.kitware.com:utils/scriptconfig.git'}, + name='scriptconfig', branch='dev/0.5.1', remote='public', + remotes={'public': 'git@gitlab.kitware.com:utils/scriptconfig.git'}, ), CommonRepo( - name='ndsampler', branch='dev/0.5.0', remote='computer-vision', - remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'}, + name='ndsampler', branch='dev/0.5.0', remote='public', + remotes={'public': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'}, ), # netharn - training harness CommonRepo( - name='netharn', branch='dev/0.5.2', remote='computer-vision', - remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/netharn.git'}, + name='netharn', branch='dev/0.5.2', remote='public', + remotes={'public': 'git@gitlab.kitware.com:computer-vision/netharn.git'}, ), ] registery = RepoRegistry(repos) -- GitLab From ff6e31fe2d5865c2e2e859a3de7fbcb7b52b63ec Mon Sep 17 00:00:00 2001 From: Jon Crall Date: Tue, 19 Nov 2019 09:07:40 -0500 Subject: [PATCH 06/24] Fix small issues in CIFAR --- CHANGELOG.md | 7 ++++- examples/cifar.py | 29 +++++++++---------- examples/mnist.py | 71 ++++++++++++++++++++++++----------------------- netharn/mixins.py | 7 ++++- 4 files changed, 63 insertions(+), 51 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c2efd13..9a6569a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 
From 68550b713a3cdaf2beffe915304baa09c3ab7c56 Mon Sep 17 00:00:00 2001
From: Jon Crall
Date: Tue, 19 Nov 2019 08:54:20 -0500
Subject: [PATCH 04/24] Fix branch

---
 super_setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/super_setup.py b/super_setup.py
index 54b3cae..154e359 100755
--- a/super_setup.py
+++ b/super_setup.py
@@ -506,7 +506,7 @@ def make_netharn_registry():

         # netharn - training harness
         CommonRepo(
-            name='netharn', branch='dev/0.5.1', remote='computer-vision',
+            name='netharn', branch='dev/0.5.2', remote='computer-vision',
             remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/netharn.git'},
         ),
     ]
--
GitLab

From 8df1a8231774704f6b5b3030061ba2aeb89210c2 Mon Sep 17 00:00:00 2001
From: Jon Crall
Date: Tue, 19 Nov 2019 08:55:03 -0500
Subject: [PATCH 05/24] wip

---
 super_setup.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/super_setup.py b/super_setup.py
index 154e359..dbec519 100755
--- a/super_setup.py
+++ b/super_setup.py
@@ -481,33 +481,33 @@ def make_netharn_registry():

         # The util libs
         CommonRepo(
-            name='kwarray', branch='dev/0.5.2', remote='computer-vision',
-            remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/kwarray.git'},
+            name='kwarray', branch='dev/0.5.2', remote='public',
+            remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwarray.git'},
         ),
         CommonRepo(
-            name='kwimage', branch='dev/0.5.2', remote='computer-vision',
-            remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/kwimage.git'},
+            name='kwimage', branch='dev/0.5.2', remote='public',
+            remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'},
         ),
         CommonRepo(
-            name='kwplot', branch='dev/0.4.0', remote='computer-vision',
-            remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/kwplot.git'},
+            name='kwplot', branch='dev/0.4.0', remote='public',
+            remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'},
         ),

         # For example data and CLI
         CommonRepo(
-            name='scriptconfig', branch='dev/0.5.1', remote='utils',
-            remotes={'utils': 'git@gitlab.kitware.com:utils/scriptconfig.git'},
+            name='scriptconfig', branch='dev/0.5.1', remote='public',
+            remotes={'public': 'git@gitlab.kitware.com:utils/scriptconfig.git'},
         ),
         CommonRepo(
-            name='ndsampler', branch='dev/0.5.0', remote='computer-vision',
-            remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'},
+            name='ndsampler', branch='dev/0.5.0', remote='public',
+            remotes={'public': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'},
         ),

         # netharn - training harness
         CommonRepo(
-            name='netharn', branch='dev/0.5.2', remote='computer-vision',
-            remotes={'computer-vision': 'git@gitlab.kitware.com:computer-vision/netharn.git'},
+            name='netharn', branch='dev/0.5.2', remote='public',
+            remotes={'public': 'git@gitlab.kitware.com:computer-vision/netharn.git'},
         ),
     ]
     registery = RepoRegistry(repos)
--
GitLab

From ff6e31fe2d5865c2e2e859a3de7fbcb7b52b63ec Mon Sep 17 00:00:00 2001
From: Jon Crall
Date: Tue, 19 Nov 2019 09:07:40 -0500
Subject: [PATCH 06/24] Fix small issues in CIFAR

---
 CHANGELOG.md      |  7 ++++-
 examples/cifar.py | 29 +++++++++----------
 examples/mnist.py | 71 ++++++++++++++++++++++++-----------------------
 netharn/mixins.py |  7 ++++-
 4 files changed, 63 insertions(+), 51 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c2efd13..9a6569a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,9 +7,14 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

 ## Version 0.5.2

 ### Added
-
 * Rectify nonlinearity now supports more torch activations

+### Changed
+* Smoothing no longer applied to lr (learning rate) and momentum monitor plots
+
+### Fixed
+* Small issues in CIFAR Example
+

 ## Version 0.5.1

diff --git a/examples/cifar.py b/examples/cifar.py
index 7807c28..caff1d6 100644
--- a/examples/cifar.py
+++ b/examples/cifar.py
@@ -190,7 +190,8 @@ class CIFAR_FitHarn(nh.FitHarn):
             'pred_scores': pred_scores,
         }
         if true_cxs is not None:
-            hot = nh.criterions.focal.one_hot_embedding(true_cxs, class_probs.shape[1])
+            import kwarray
+            hot = kwarray.one_hot_embedding(true_cxs, class_probs.shape[1])
             true_probs = (hot * class_probs).sum(dim=1)
             decoded['true_scores'] = true_probs
         return decoded
@@ -198,7 +199,7 @@ class CIFAR_FitHarn(nh.FitHarn):
     def _draw_batch(harn, batch, decoded, limit=32):
         """
         CommandLine:
-            xdoctest -m ~/code/netharn/examples/cifar.py CIFAR_FitHarn._draw_batch --show --arch=wrn_22
+            xdoctest -m ~/code/netharn/examples/cifar.py CIFAR_FitHarn._draw_batch --show --arch=resnet50

         Example:
             >>> import sys
@@ -210,11 +211,12 @@ class CIFAR_FitHarn(nh.FitHarn):
             >>> decoded = harn._decode(outputs, batch['label'])
             >>> stacked = harn._draw_batch(batch, decoded, limit=42)
             >>> # xdoctest: +REQUIRES(--show)
-            >>> import netharn as nh
-            >>> nh.util.autompl()
-            >>> nh.util.imshow(stacked, colorspace='rgb', doclf=True)
-            >>> nh.util.show_if_requested()
+            >>> import kwplot
+            >>> kwplot.autompl()
+            >>> kwplot.imshow(stacked, colorspace='rgb', doclf=True)
+            >>> kwplot.show_if_requested()
         """
+        import kwimage
         inputs = batch['input']
         inputs = inputs[0:limit]
@@ -256,25 +258,25 @@ class CIFAR_FitHarn(nh.FitHarn):
             }
             color = 'dodgerblue' if pcx == tcx else 'orangered'
-            im_ = nh.util.draw_text_on_image(im_, pred_label, org=org1 - 2,
+            im_ = kwimage.draw_text_on_image(im_, pred_label, org=org1 - 2,
                                              color='white', **fontkw)
-            im_ = nh.util.draw_text_on_image(im_, true_label, org=org2 - 2,
+            im_ = kwimage.draw_text_on_image(im_, true_label, org=org2 - 2,
                                              color='white', **fontkw)

             for i in [-2, -1, 1, 2]:
                 for j in [-2, -1, 1, 2]:
-                    im_ = nh.util.draw_text_on_image(im_, pred_label, org=org1 + i,
+                    im_ = kwimage.draw_text_on_image(im_, pred_label, org=org1 + i,
                                                      color='black', **fontkw)
-                    im_ = nh.util.draw_text_on_image(im_, true_label, org=org2 + j,
+                    im_ = kwimage.draw_text_on_image(im_, true_label, org=org2 + j,
                                                      color='black', **fontkw)

-            im_ = nh.util.draw_text_on_image(im_, pred_label, org=org1,
+            im_ = kwimage.draw_text_on_image(im_, pred_label, org=org1,
                                              color=color, **fontkw)
-            im_ = nh.util.draw_text_on_image(im_, true_label, org=org2,
+            im_ = kwimage.draw_text_on_image(im_, true_label, org=org2,
                                              color='lawngreen', **fontkw)
             todraw.append(im_)

-        stacked = nh.util.stack_images_grid(todraw, overlap=-10, bg_value=(10, 40, 30), chunksize=8)
+        stacked = kwimage.stack_images_grid(todraw, overlap=-10, bg_value=(10, 40, 30), chunksize=8)
         return stacked
@@ -304,7 +306,6 @@ def setup_harn():
     from torchvision import transforms

     config = {
-        # TODO: the fast.ai baseline
         # 'arch': ub.argval('--arch', default='wrn_22'),
         # 'schedule': ub.argval('--arch', default='onecycle'),
         # 'lr': float(ub.argval('--lr', default=0.003)),
diff --git a/examples/mnist.py b/examples/mnist.py
index 75c34ad..f669976 100644
--- a/examples/mnist.py
+++ b/examples/mnist.py
@@ -4,9 +4,8 @@ fit_harness takes your hyperparams and applys standardized "state-of-the-art"
 training procedures But everything is overwritable.
-Experimentation and freedom to protype quickly is extremely important
-We do our best not to get in the way, just performing a jumping off
-point.
+Experimentation and freedom to prototype quickly is extremely important.
+We do our best not to get in the way, just providing a jumping off point.
 """
 from __future__ import absolute_import, division, print_function, unicode_literals
 import ubelt as ub
@@ -78,11 +77,12 @@ class MnistHarn(nh.FitHarn):
         true_labels = batch['label'].cpu().numpy()

         if harn.batch_index < 3:
+            import kwimage
             decoded = harn._decode(outputs, batch['label'])
             stacked = harn._draw_batch(batch, decoded)
             dpath = ub.ensuredir((harn.train_dpath, 'monitor', harn.current_tag))
             fpath = join(dpath, 'epoch_{}_batch_{}.jpg'.format(harn.epoch, harn.batch_index))
-            nh.util.imwrite(fpath, stacked)
+            kwimage.imwrite(fpath, stacked)

         acc = (true_labels == pred_labels).mean()

@@ -109,9 +109,6 @@ class MnistHarn(nh.FitHarn):

     def _draw_batch(harn, batch, decoded, limit=32):
         """
-        CommandLine:
-            xdoctest -m ~/code/netharn/examples/cifar.py CIFAR_FitHarn._draw_batch --show --arch=wrn_22
-
         Example:
             >>> import sys
             >>> sys.path.append('/home/joncrall/code/netharn/examples')
@@ -125,11 +122,13 @@ class MnistHarn(nh.FitHarn):
             >>> fpath = harn._draw_batch(bx, batch, decoded, limit=42)
             >>> print('fpath = {!r}'.format(fpath))
             >>> # xdoctest: +REQUIRES(--show)
-            >>> import netharn as nh
-            >>> nh.util.autompl()
-            >>> nh.util.imshow(fpath, colorspace='rgb', doclf=True)
-            >>> nh.util.show_if_requested()
+            >>> import kwplot
+            >>> kwplot.autompl()
+            >>> kwplot.imshow(fpath, colorspace='rgb', doclf=True)
+            >>> kwplot.show_if_requested()
         """
+        import kwimage
+        import kwplot
         inputs = batch['input']
         inputs = inputs[0:limit]
@@ -150,12 +149,12 @@ class MnistHarn(nh.FitHarn):
         todraw = []
         for im, pcx, tcx, probs in zip(inputs, pred_cxs, true_cxs, class_probs):
             im_ = im.transpose(1, 2, 0)
-            im_ = nh.util.convert_colorspace(im_, 'gray', 'rgb')
+            im_ = kwimage.convert_colorspace(im_, 'gray', 'rgb')
             im_ = np.ascontiguousarray(im_)
-            im_ = nh.util.draw_clf_on_image(im_, dset.classes, tcx, probs)
+            im_ = kwplot.draw_clf_on_image(im_, dset.classes, tcx, probs)
             todraw.append(im_)

-        stacked = nh.util.stack_images_grid(todraw, overlap=-10, bg_value=(10, 40, 30), chunksize=8)
+        stacked = kwimage.stack_images_grid(todraw, overlap=-10, bg_value=(10, 40, 30), chunksize=8)
         return stacked
@@ -255,28 +254,30 @@ def setup_harn(**kw):
         datasets=datasets,
         loaders=loaders,
         model=(MnistNet, dict(num_channels=1, classes=datasets['train'].classes)),
-        # optimizer=torch.optim.Adam,
+        # optimizer=torch.optim.AdamW,
         optimizer=(torch.optim.SGD, {'lr': 0.01, 'weight_decay': 3e-6}),
-        scheduler='ReduceLROnPlateau',
-        # scheduler=(nh.schedulers.ListedScheduler, {
-        #     'points': {
-        #         'lr': {
-        #             0  : 0.01,
-        #             10 : 0.10,
-        #             20 : 0.01,
-        #             40 : 0.0001,
-        #         },
-        #         'momentum': {
-        #             0  : 0.95,
-        #             10 : 0.85,
-        #             20 : 0.95,
-        #             40 : 0.99,
-        #         },
-        #         'weight_decay': {
-        #             0: 3e-6,
-        #         }
-        #     }
-        # }),
+        # scheduler='ReduceLROnPlateau',
+        scheduler=(nh.schedulers.ListedScheduler, {
+            'points': {
+                'lr': {
+                    0  : 0.01,
+                    2  : 0.05,
+                    10 : 0.10,
+                    20 : 0.01,
+                    40 : 0.0001,
+                },
+                'momentum': {
+                    0  : 0.95,
+                    10 : 0.85,
+                    20 : 0.95,
+                    40 : 0.99,
+                },
+                'weight_decay': {
+                    0: 3e-6,
+                }
+            },
+            'interpolation': 'linear',
+        }),
         criterion=torch.nn.CrossEntropyLoss,
         initializer=initializer,
         monitor=(nh.Monitor, {
diff --git a/netharn/mixins.py b/netharn/mixins.py
index 9caddd7..68038e6 100644
--- a/netharn/mixins.py
+++ b/netharn/mixins.py
@@ -245,6 +245,9 @@ def _dump_measures(tb_data, out_dpath, mode=None, smoothing=0.6,
             high = high_ if high is None else max(high_, high)
         return (low, high)

+    # Hack values that we don't apply smoothing to
+    HACK_NO_SMOOTH = ['lr', 'momentum']
+
     GROUP_LOSSES = True
     if GROUP_LOSSES:
         # Group all losses in one plot for comparison
@@ -268,7 +271,9 @@ def _dump_measures(tb_data, out_dpath, mode=None, smoothing=0.6,
             xydata = ub.odict()
             for key in sorted(losses):
                 ydata = tb_data[key]['ydata']
-                ydata = smooth_curve(ydata, smoothing)
+
+                if key not in HACK_NO_SMOOTH:
+                    ydata = smooth_curve(ydata, smoothing)

                 try:
                     pos_ys = ydata[ydata > 0]
--
GitLab
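The ListedScheduler block above lists per-hyperparameter breakpoints keyed by epoch. A standalone sketch of what the ``'interpolation': 'linear'`` option implies for values between breakpoints (illustrative only; the real logic lives in ``netharn.schedulers`` and may differ in detail):

    import bisect

    def value_at(epoch, points):
        # Piecewise-linear interpolation between listed (epoch, value) pairs
        keys = sorted(points)
        if epoch <= keys[0]:
            return points[keys[0]]
        if epoch >= keys[-1]:
            return points[keys[-1]]
        i = bisect.bisect_right(keys, epoch) - 1
        e0, e1 = keys[i], keys[i + 1]
        alpha = float(epoch - e0) / (e1 - e0)
        return points[e0] * (1 - alpha) + points[e1] * alpha

    lr_points = {0: 0.01, 2: 0.05, 10: 0.10, 20: 0.01, 40: 0.0001}
    assert abs(value_at(1, lr_points) - 0.03) < 1e-9  # halfway from 0.01 to 0.05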
From 20d4f06c0d2a8630f29456f40ba6e4c69e324cb5 Mon Sep 17 00:00:00 2001
From: joncrall
Date: Tue, 19 Nov 2019 09:20:08 -0500
Subject: [PATCH 07/24] enhance cifar example

---
 examples/cifar.py   | 36 +++++++++++++++++++++++++++---------
 netharn/fit_harn.py |  5 +++++
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/examples/cifar.py b/examples/cifar.py
index caff1d6..4069e01 100644
--- a/examples/cifar.py
+++ b/examples/cifar.py
@@ -41,6 +41,7 @@ CommandLine:
     python examples/cifar.py --gpu=1,2 --arch=resnet50 --lr=0.003 --schedule=onecycle --optim=adamw
 """
+import sys
 from os.path import join
 import numpy as np
 import ubelt as ub
@@ -101,6 +102,17 @@ class CIFAR_FitHarn(nh.FitHarn):
         loss = harn.criterion(outputs, labels)
         return outputs, loss

+    # def backpropogate(harn, bx, batch, loss):
+    #     """
+    #     Note: this function usually does not need to be overloaded,
+    #     but you can if you want to. The actual base implementation is
+    #     slightly more nuanced. For details see:
+    #     :func:netharn.fit_harn.CoreCallbacks.backpropogate
+    #     """
+    #     loss.backward()
+    #     harn.optimizer.step()
+    #     harn.optimizer.zero_grad()
+
     def on_batch(harn, batch, outputs, loss):
         """
         Custom code executed at the end of each batch.
@@ -305,12 +317,9 @@ def setup_harn():
     import torchvision
     from torchvision import transforms

+    # Note that most netharn training scripts will use scriptconfig instead of
+    # this more explicit approach.
     config = {
-        # 'arch': ub.argval('--arch', default='wrn_22'),
-        # 'schedule': ub.argval('--arch', default='onecycle'),
-        # 'lr': float(ub.argval('--lr', default=0.003)),
-
         # A conservative traditional baseline
         'arch': ub.argval('--arch', default='resnet50'),
         'lr': float(ub.argval('--lr', default=0.1)),
@@ -403,6 +412,7 @@ def setup_harn():
                                      transform=transform_test),
     }
     if True:
+        # Create a train/validation split
         learn = datasets['train']
         indices = np.arange(len(learn))
         indices = nh.util.shuffle(indices, rng=0)
@@ -411,7 +421,7 @@ def setup_harn():
         datasets['train'] = torch.utils.data.Subset(learn, indices[num_vali:])

     # For some reason the torchvision objects do not make the category names
-    # easilly available. We set them here for ease of use.
+    # easily available. We set them here for ease of use.
     reduction = int(ub.argval('--reduction', default=1))
     for key, dset in datasets.items():
         dset.categories = categories
@@ -555,7 +565,7 @@ def setup_harn():

     # Notice that arguments to hyperparameters are typically specified as a
     # tuple of (type, Dict), where the dictionary are the keyword arguments
-    # that can be used to instanciate an instance of that class. While
+    # that can be used to instantiate an instance of that class. While
     # this may be slightly awkward, it enables netharn to track hyperparameters
     # more effectively. Note that it is possible to simply pass an already
     # constructed instance of a class, but this causes information loss.
@@ -563,11 +573,13 @@ def setup_harn():
         # Datasets must be preconstructed
         datasets=datasets,
         nice='cifar10_' + config['arch'],
-        # Loader preconstructed
+        # Loader may be preconstructed
         loaders=loaders,
         workdir=config['workdir'],
         xpu=xpu,
         # The 6 major hyper components are best specified as a Tuple[type, dict]
+        # However, in recent releases of netharn, these may be preconstructed
+        # as well.
         model=model_,
         optimizer=optimizer_,
         scheduler=scheduler_,
@@ -586,6 +598,12 @@ def setup_harn():
             # Specify anything else that is special about your hyperparams here
             # Especially if you make a custom_batch_runner
         },
+        # These extra arguments are recorded in the train_info.json but do
+        # not contribute to the hyperparameter hash.
+        extra={
+            'config': ub.repr2(config),
+            'argv': sys.argv,
+        }
     )

     # Creating an instance of a Fitharn object is typically fast.
@@ -606,7 +624,7 @@ def main():
     if ub.argval(('--vd', '--view-directory')):
         ub.startfile(harn.train_dpath)

-    # This starts the main loop which will run until a the monitor's terminator
+    # This starts the main loop which will run until the monitor's terminator
     # criterion is satisfied. If the initialize step loaded a checkpointed that
     # already met the termination criterion, then this will simply return.
     deploy_fpath = harn.run()
diff --git a/netharn/fit_harn.py b/netharn/fit_harn.py
index 1118987..79b6d59 100644
--- a/netharn/fit_harn.py
+++ b/netharn/fit_harn.py
@@ -1934,6 +1934,11 @@ class CoreCallbacks(object):
         TODO:
             - [ ] perhaps remove dynamics as a netharn core component and
             simply allow the end-application to take care of that detail.
+
+        Args:
+            bx (int): the current batch index
+            batch (object): the current batch
+            loss (Tensor): the loss computed in `run_batch`.
         """
         loss.backward()
--
GitLab

From b9fc98e0781b2728767881fa6731f64270abd92a Mon Sep 17 00:00:00 2001
From: joncrall
Date: Tue, 19 Nov 2019 09:53:29 -0500
Subject: [PATCH 08/24] Improve docs

---
 netharn/fit_harn.py | 144 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 139 insertions(+), 5 deletions(-)

diff --git a/netharn/fit_harn.py b/netharn/fit_harn.py
index 79b6d59..e812354 100644
--- a/netharn/fit_harn.py
+++ b/netharn/fit_harn.py
@@ -665,6 +665,15 @@ class ProgMixin(object):
         return Prog(*args, **kw)

     def _batch_msg(harn, metric_dict, batch_size, learn=False):
+        """
+        Args:
+            metric_dict (dict): metrics to be reported in the message
+            batch_size (int): size of the current batch
+            learn (bool): formats a message for train or vali/test.
+
+        Returns:
+            str : the message to be used in the progress bar
+        """
         parts = ['{}:{:.4g}'.format(k, v) for k, v in metric_dict.items()]
         if harn.config['prog_backend'] == 'progiter':
             if learn and harn.scheduler and getattr(harn.scheduler, '__batchaware__', False):
@@ -735,9 +744,21 @@ class LogMixin(object):
         pass

     def log(harn, msg):
+        """
+        Logs an info message. Alias of :func:LogMixin.info
+
+        Args:
+            msg (str): an info message to log
+        """
         harn.info(msg)

     def info(harn, msg):
+        """
+        Writes an info message to the logs
+
+        Args:
+            msg (str): an info message to log
+        """
         harn._ensure_prog_newline()
         if harn._log:
             try:
@@ -748,6 +769,12 @@ class LogMixin(object):
             print(msg)

     def error(harn, msg):
+        """
+        Writes an error message to the logs
+
+        Args:
+            msg (str): an error message to log
+        """
         harn._ensure_prog_newline()
         if harn._log:
             msg = strip_ansi(msg)
@@ -756,6 +783,12 @@ class LogMixin(object):
             print(msg)

     def warn(harn, msg):
+        """
+        Writes a warning message to the logs
+
+        Args:
+            msg (str): a warning message to log
+        """
         harn._ensure_prog_newline()
         if harn._log:
             msg = strip_ansi(msg)
@@ -764,6 +797,12 @@ class LogMixin(object):
             print(msg)

     def debug(harn, msg):
+        """
+        Writes a debug message to the logs
+
+        Args:
+            msg (str): a debug message to log
+        """
         if harn._log:
             msg = strip_ansi(six.text_type(msg))
             # Encode to prevent errors on windows terminals
@@ -838,6 +877,10 @@ class SnapshotMixin(object):

     @property
     def snapshot_dpath(harn):
+        """
+        Returns:
+            str : path to the snapshot directory
+        """
         # TODO: we should probably change the name of this directory to either
         # snapshots or checkpoints for simplicity.
         if harn.train_dpath is None:
@@ -853,6 +896,9 @@ class SnapshotMixin(object):
         Keeps `num_keep_recent` most recent, `num_keep_best` best, and one
         every `keep_freq` epochs.

+        Returns:
+            set: epoch numbers to remove
+
         Doctest:
             >>> import netharn as nh
             >>> harn = FitHarn({})
@@ -919,6 +965,9 @@ class SnapshotMixin(object):
     def backtrack_weights(harn, epoch):
         """
         Reset the weights to a previous good state
+
+        Args:
+            epoch (int): the epoch to backtrack to
         """
         load_path = join(harn.snapshot_dpath, '_epoch_{:08d}.pt'.format(epoch))
         snapshot = harn.xpu.load(load_path)
@@ -1072,6 +1121,9 @@ class ScheduleMixin(object):
     def _current_lrs(harn):
         """
         Get the of distinct learning rates (usually only 1) currently in use
+
+        Returns:
+            List[float]: list of current learning rates
         """
         # optim_lrs = {group['lr'] for group in harn.optimizer.param_groups}
         optim_lrs = [group['lr'] for group in harn.optimizer.param_groups]
@@ -1127,6 +1179,10 @@ class ScheduleMixin(object):
         """
         helper function to change the learning rate that handles the way
         that different schedulers might be used.
+
+        Args:
+            improved (bool | None): if specified, indicates whether the
+                validation metrics have improved (used by ReduceLROnPlateau scheduler)
         """
         epoch_that_just_finished = harn.epoch
         if harn.scheduler is None:
@@ -1199,6 +1255,14 @@ class CoreMixin(object):
         checkpointed that already met the termination criterion, then this
         will simply return.

+        Notes:
+            If harn.config['keyboard_debug'] is True, then pressing Ctrl+C
+            while this is running will result in an interactive prompt which
+            allows some amount of manual control over the training run.
+
+        Raises:
+            TrainingDiverged: if training fails due to numerical issues
+
         Returns:
             PathLike: deploy_fpath: the path to the standalone deployed model
         """
@@ -1361,7 +1425,12 @@ class CoreMixin(object):
         return deploy_fpath

     def _export(harn):
-        """ Export the model topology to the train_dpath """
+        """
+        Export the model topology to the train_dpath
+
+        Returns:
+            str: path to the exported model topology
+        """
         # TODO: might be good to check for multiple model exports at this time
         harn.debug('exporting model topology')
         static_modpath = None
@@ -1383,6 +1452,9 @@ class CoreMixin(object):
         Packages the best validation (or most recent) weights with the exported
         model topology into a single-file model deployment that is "mostly"
         independent of the code used to train the model.
+
+        Returns:
+            str: path to the deploy zipfile.
         """
         harn._export()
         harn.debug('packaging deploying model')
@@ -1416,6 +1488,11 @@ class CoreMixin(object):
     def _run_tagged_epochs(harn, train_loader, vali_loader, test_loader):
         """
         Runs one epoch of train, validation, and testing
+
+        Args:
+            train_loader (torch.utils.data.DataLoader | None): train loader
+            vali_loader (torch.utils.data.DataLoader | None): vali loader
+            test_loader (torch.utils.data.DataLoader | None): test loader
         """
         if harn._check_termination():
             raise StopTraining()
@@ -1516,10 +1593,11 @@ class CoreMixin(object):
         evaluate the model on test / train / or validation data

         Args:
-            loader : the loader for your current data split (this will
-                usually be harn.loaders[tag]
+            loader (torch.utils.data.DataLoader):
+                the loader for your current data split (this will usually be
+                ``harn.loaders[tag]``)

-            tag : the label for the loader's data split
+            tag (str) : the label for the loader's data split

             learn (bool, default=False): if True, the weights of
                 harn.model are updated by harn.optimizer
@@ -1713,7 +1791,29 @@ class CoreMixin(object):
     @profiler.profile
     def _on_batch(harn, bx, batch, outputs, loss, loss_parts=None):
-        """ Internal function that prepares to call the `on_batch` callback. """
+        """
+        Internal function that prepares to call the
+        :func:CoreCallbacks.on_batch callback.
+
+        Args:
+            bx (int): the current batch index
+
+            batch (object): the current batch
+
+            outputs (object): the first result of :func:CoreCallbacks.run_batch
+                These are the raw network outputs.
+
+            loss (Tensor): the second result of :func:CoreCallbacks.run_batch
+                This is the batch loss computed by the criterion.
+
+            loss_parts (Dict[str, Tensor]): components of the loss to be
+                individually logged.
+
+        Returns:
+            Dict[str, float]: dictionary of logged metrics. This is the
+                union of the metrics returned by the user as well as additional
+                loss information added in this function.
+        """
         loss_value = float(loss.data.cpu().item())
         loss_value = harn._check_loss(loss_value)
@@ -1739,6 +1839,11 @@ class ChecksMixin(object):

     def _check_gradients(harn):
         """
+        Checks that the accumulated gradients are all finite.
+
+        Raises:
+            TrainingDiverged: if checks fail
+
         Example:
             harn = ...
            all_grads = harn._check_gradients()
        """

    @profiler.profile
@@ -1756,6 +1861,12 @@ class ChecksMixin(object):
    def _check_loss(harn, loss_value):
+        """
+        Checks that the loss is not too large
+
+        Raises:
+            TrainingDiverged: if checks fail
+        """
        if not np.isfinite(loss_value):
            harn.warn('WARNING: got inf loss, setting loss to a large value')
            loss_value = harn.config['large_loss'] * 10
@@ -1768,6 +1879,12 @@ class ChecksMixin(object):

    @profiler.profile
    def _check_divergence(harn):
+        """
+        Checks that the model weights are all finite
+
+        Raises:
+            TrainingDiverged: if checks fail
+        """
        # Eventually we may need to remove
        # num_batches_tracked once 0.5.0 lands
        state = harn.model.module.state_dict()
@@ -1868,6 +1985,13 @@ class CoreCallbacks(object):
        ensure batch is in a standardized structure

        Overload Encouraged, but not always necessary
+
+        Args:
+            raw_batch (object): the raw batch generated by the loader
+
+        Returns:
+            object: batch - the prepared batch where relevant inputs have
+                been moved onto the appropriate XPU(s).
        """
        try:
            if isinstance(raw_batch, (tuple, list)):
@@ -1901,6 +2025,9 @@ class CoreCallbacks(object):
        tensors. In this case, the total loss will be the sum of the values
        and each loss component will be automatically logged.

+        Args:
+            batch (object): the current batch
+
        Returns:
            Tuple[object, Tensor|Dict]: (outputs, loss)
        """
@@ -1976,6 +2103,13 @@ class CoreCallbacks(object):

        Overload Encouraged

+        Args:
+            batch (object): the current batch
+            outputs (object): the first result of :func:CoreCallbacks.run_batch
+                These are the raw network outputs.
+            loss (object): the second result of :func:CoreCallbacks.run_batch
+                This is the batch loss computed by the criterion.
+
        Returns:
            dict or None: dictionary of scalar batch measures
        """
--
GitLab
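The docstrings added in this patch pin down the callback contract: ``run_batch`` returns ``(outputs, loss)``, where ``loss`` may be a single Tensor or a Dict of named tensors whose values are summed into the total loss and logged individually. A hypothetical subclass using the dict form (the ``'cls'``/``'reg'`` names and the model/criterion wiring are illustrative, not from the patch):

    import netharn as nh

    class MyHarn(nh.FitHarn):
        def run_batch(harn, batch):
            # `batch` was already moved onto the XPU by `prepare_batch`
            outputs = harn.model(batch['input'])
            # each named part is logged separately; their sum is the loss
            loss_parts = {
                'cls': harn.criterion(outputs, batch['label']),
                'reg': 1e-4 * sum(p.abs().sum() for p in harn.model.parameters()),
            }
            return outputs, loss_parts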
From 141f169d86bf361f9e386f56de28d826978a7527 Mon Sep 17 00:00:00 2001
From: joncrall
Date: Tue, 19 Nov 2019 10:24:21 -0500
Subject: [PATCH 09/24] Fix segmentation example and focal loss

---
 CHANGELOG.md                |   1 +
 examples/segmentation.py    |  20 ++++-
 examples/sseg_camvid.py     | 173 ++++++++++++++++++++----------------
 netharn/criterions/focal.py |   2 +
 4 files changed, 118 insertions(+), 78 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a6569a..a26085b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

 ### Fixed
 * Small issues in CIFAR Example
+* Small `imgaug` issue in `examples/sseg_camvid.py` and `examples/segmentation.py`

 ## Version 0.5.1

diff --git a/examples/segmentation.py b/examples/segmentation.py
index 1560291..f431106 100644
--- a/examples/segmentation.py
+++ b/examples/segmentation.py
@@ -147,9 +147,17 @@ class SegmentationDataset(torch.utils.data.Dataset):
         if self.augmenter:
             augdet = self.augmenter.to_deterministic()
             imdata = augdet.augment_image(imdata)
-            cidx_segmap_oi = imgaug.SegmentationMapOnImage(cidx_segmap, cidx_segmap.shape, nb_classes=len(self.classes))
-            cidx_segmap_oi = augdet.augment_segmentation_maps([cidx_segmap_oi])[0]
-            cidx_segmap = cidx_segmap_oi.arr.argmax(axis=2)
+            if hasattr(imgaug, 'SegmentationMapsOnImage'):
+                # Oh imgaug, stop breaking.
+                cidx_segmap_oi = imgaug.SegmentationMapsOnImage(cidx_segmap, cidx_segmap.shape)
+                cidx_segmap_oi = augdet.augment_segmentation_maps(cidx_segmap_oi)
+                assert cidx_segmap_oi.arr.shape[2] == 1
+                cidx_segmap = cidx_segmap_oi.arr[..., 0]
+                cidx_segmap = np.ascontiguousarray(cidx_segmap)
+            else:
+                cidx_segmap_oi = imgaug.SegmentationMapOnImage(cidx_segmap, cidx_segmap.shape, nb_classes=len(self.classes))
+                cidx_segmap_oi = augdet.augment_segmentation_maps([cidx_segmap_oi])[0]
+                cidx_segmap = cidx_segmap_oi.arr.argmax(axis=2)

         im_chw = torch.FloatTensor(
             imdata.transpose(2, 0, 1).astype(np.float32) / 255.)
@@ -361,6 +369,12 @@ class SegmentationHarn(nh.FitHarn):
                 true_img = kwimage.ensure_uint255(true_img)
                 pred_img = kwimage.ensure_uint255(pred_img)

+                true_img = kwimage.draw_text_on_image(
+                    true_img, 'true', org=(0, 0), valign='top', color='blue')
+
+                pred_img = kwimage.draw_text_on_image(
+                    pred_img, 'pred', org=(0, 0), valign='top', color='blue')
+
                 item_img = kwimage.stack_images([pred_img, true_img], axis=1)
                 batch_imgs.append(item_img)

diff --git a/examples/sseg_camvid.py b/examples/sseg_camvid.py
index a9ed60f..5c8e6f7 100644
--- a/examples/sseg_camvid.py
+++ b/examples/sseg_camvid.py
@@ -1,4 +1,12 @@
 # -*- coding: utf-8 -*-
+"""
+Train an example semantic segmentation model on the CamVid dataset.
+For a more general segmentation example that works with any (ndsampler-style)
+MS-COCO dataset, see segmentation.py.
+
+CommandLine:
+    python ~/code/netharn/examples/sseg_camvid.py --workers=4 --xpu=0 --batch_size=2 --nice=expt1
+"""
 from __future__ import absolute_import, division, print_function, unicode_literals
 from os.path import join
 import ubelt as ub
@@ -15,6 +23,56 @@ import imgaug.augmenters as iaa
 import imgaug

+class SegmentationConfig(scfg.Config):
+    """
+    Default configuration for setting up a training session
+    """
+    default = {
+        'nice': scfg.Path('untitled', help='A human readable tag that is "nice" for humans'),
+        'workdir': scfg.Path('~/work/camvid', help='Dump all results in your workdir'),
+
+        'workers': scfg.Value(0, help='number of parallel dataloading jobs'),
+        'xpu': scfg.Value('argv', help='See netharn.XPU for details. can be cpu/gpu/cuda0/0,1,2,3)'),
+
+        'augment': scfg.Value('simple', help='type of training dataset augmentation'),
+        'class_weights': scfg.Value('log-median-idf', help='how to weight imbalanced classes'),
+        # 'class_weights': scfg.Value(None, help='how to weight imbalanced classes'),
+
+        'datasets': scfg.Value('special:camvid', help='Eventually you may be able to specify a coco file'),
+        'train_dataset': scfg.Value(None),
+        'vali_dataset': scfg.Value(None),
+
+        'arch': scfg.Value('psp', help='Network architecture code'),
+        'optim': scfg.Value('adamw', help='Weight optimizer. Can be SGD, ADAM, ADAMW, etc..'),
+
+        'input_dims': scfg.Value((128, 128), help='Window size to input to the network'),
+        'input_overlap': scfg.Value(0.25, help='amount of overlap when creating a sliding window dataset'),
+
+        'batch_size': scfg.Value(4, help='number of items per batch'),
+        'bstep': scfg.Value(1, help='number of batches before a gradient descent step'),
+
+        'max_epoch': scfg.Value(140, help='Maximum number of epochs'),
+        'patience': scfg.Value(140, help='Maximum "bad" validation epochs before early stopping'),
+
+        'lr': scfg.Value(1e-3, help='Base learning rate'),
+        'decay': scfg.Value(1e-5, help='Base weight decay'),
+
+        'focus': scfg.Value(2.0, help='focus for focal loss'),
+
+        'schedule': scfg.Value('step90', help=('Special coercible netharn code. Eg: onecycle50, step50, gamma')),
+
+        'init': scfg.Value('kaiming_normal', help='How to initialize weights. (can be a path to a pretrained model)'),
+        'pretrained': scfg.Path(help=('alternative way to specify a path to a pretrained model')),
+    }
+
+    def normalize(self):
+        if self['pretrained'] in ['null', 'None']:
+            self['pretrained'] = None
+
+        if self['pretrained'] is not None:
+            self['init'] = 'pretrained'
+
+
 class SegmentationDataset(torch.utils.data.Dataset):
     """
     Efficient loader for training on a sementic segmentation dataset

     Example:
         >>> #input_dims = (224, 224)
         >>> input_dims = (512, 512)
         >>> self = dset = SegmentationDataset(sampler, input_dims)
-        >>> output = self[10]
+        >>> item = self[10]
         >>> # xdoctest: +REQUIRES(--show)
         >>> import kwplot
         >>> plt = kwplot.autoplt()
-        >>> cidxs = output['class_idxs']
+        >>> cidxs = item['class_idxs']
         >>> colored_labels = self._colorized_labels(cidxs)
         >>> kwplot.figure(doclf=True)
-        >>> kwplot.imshow(output['im'])
+        >>> kwplot.imshow(item['im'])
         >>> kwplot.imshow(colored_labels, alpha=.4)

     Example:
         >>> import kwplot
         >>> plt = kwplot.autoplt()
         >>> indices = list(range(len(self)))
         >>> for index in xdev.InteractiveIter(indices):
-        >>>     output = self[index]
-        >>>     cidxs = output['class_idxs']
+        >>>     item = self[index]
+        >>>     cidxs = item['class_idxs']
         >>>     colored_labels = self._colorized_labels(cidxs)
         >>>     kwplot.figure(doclf=True)
-        >>>     kwplot.imshow(output['im'])
+        >>>     kwplot.imshow(item['im'])
         >>>     kwplot.imshow(colored_labels, alpha=.4)
         >>>     xdev.InteractiveIter.draw()
     """
     def __init__(self, sampler, input_dims=(224, 224), input_overlap=0.5,
-                 augmenter=False):
+                 augment=False):
         self.input_dims = None
         self.input_id = None
         self.cid_to_cidx = None
@@ -68,30 +126,31 @@ class SegmentationDataset(torch.utils.data.Dataset):
         self.cid_to_cidx = sampler.catgraph.id_to_idx
         self.classes = sampler.catgraph

+        self.augmenter = self._rectify_augmenter(augment)
+
         # Create a slider for every image
         self._build_sliders(input_dims=input_dims, input_overlap=input_overlap)
-        self.augmenter = self._rectify_augmenter(augmenter)

-    def _rectify_augmenter(self, augmenter):
+    def _rectify_augmenter(self, augment):
         import netharn as nh
-        if augmenter is True:
-            augmenter = 'simple'
+        if augment is True:
+            augment = 'simple'

-        if not augmenter:
+        if not augment:
             augmenter = None
-        elif augmenter == 'simple':
+        elif augment == 'simple':
             augmenter = iaa.Sequential([
                 iaa.Crop(percent=(0, .2)),
                 iaa.Fliplr(p=.5)
             ])
-        elif augmenter == 'complex':
+        elif augment == 'complex':
             augmenter = iaa.Sequential([
                 iaa.Sometimes(0.2, nh.data.transforms.HSVShift(hue=0.1, sat=1.5, val=1.5)),
                 iaa.Crop(percent=(0, .2)),
                 iaa.Fliplr(p=.5)
             ])
         else:
-            raise KeyError('Unknown augmentation {!r}'.format(self.augment))
+            raise KeyError('Unknown augmentation {!r}'.format(augment))
         return augmenter

     def _build_sliders(self, input_dims=(224, 224), input_overlap=0.5):
@@ -121,14 +180,14 @@ class SegmentationDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         """
         Example:
-            >>> self = SegmentationDataset.demo(augment=True)
+            >>> self = SegmentationDataset.demo(augment='complex')
-            >>> output = self[10]
+            >>> item = self[10]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> plt = kwplot.autoplt()
-            >>> colored_labels = self._colorized_labels(output['class_idxs'])
+            >>> colored_labels = self._colorized_labels(item['class_idxs'])
            >>> kwplot.figure(doclf=True)
-            >>> 
kwplot.imshow(output['im']) + >>> kwplot.imshow(item['im']) >>> kwplot.imshow(colored_labels, alpha=.4) """ outer, inner = self.subindex.unravel(index) @@ -147,9 +206,17 @@ class SegmentationDataset(torch.utils.data.Dataset): if self.augmenter: augdet = self.augmenter.to_deterministic() imdata = augdet.augment_image(imdata) - cidx_segmap_oi = imgaug.SegmentationMapOnImage(cidx_segmap, cidx_segmap.shape, nb_classes=len(self.classes)) - cidx_segmap_oi = augdet.augment_segmentation_maps([cidx_segmap_oi])[0] - cidx_segmap = cidx_segmap_oi.arr.argmax(axis=2) + if hasattr(imgaug, 'SegmentationMapsOnImage'): + # Oh imgaug, stop breaking. + cidx_segmap_oi = imgaug.SegmentationMapsOnImage(cidx_segmap, cidx_segmap.shape) + cidx_segmap_oi = augdet.augment_segmentation_maps(cidx_segmap_oi) + assert cidx_segmap_oi.arr.shape[2] == 1 + cidx_segmap = cidx_segmap_oi.arr[..., 0] + cidx_segmap = np.ascontiguousarray(cidx_segmap) + else: + cidx_segmap_oi = imgaug.SegmentationMapOnImage(cidx_segmap, cidx_segmap.shape, nb_classes=len(self.classes)) + cidx_segmap_oi = augdet.augment_segmentation_maps([cidx_segmap_oi])[0] + cidx_segmap = cidx_segmap_oi.arr.argmax(axis=2) im_chw = torch.FloatTensor( imdata.transpose(2, 0, 1).astype(np.float32) / 255.) @@ -157,12 +224,12 @@ class SegmentationDataset(torch.utils.data.Dataset): cidxs = torch.LongTensor(cidx_segmap) weight = (1 - (cidxs == 0).float()) - output = { + item = { 'im': im_chw, 'class_idxs': cidxs, 'weight': weight, } - return output + return item def _sample_to_sseg_heatmap(self, imdata, sample): annots = sample['annots'] @@ -345,6 +412,12 @@ class SegmentationHarn(nh.FitHarn): true_img = kwimage.ensure_uint255(true_img) pred_img = kwimage.ensure_uint255(pred_img) + true_img = kwimage.draw_text_on_image( + true_img, 'true', org=(0, 0), valign='top', color='blue') + + pred_img = kwimage.draw_text_on_image( + pred_img, 'pred', org=(0, 0), valign='top', color='blue') + item_img = kwimage.stack_images([pred_img, true_img], axis=1) batch_imgs.append(item_img) @@ -704,56 +777,6 @@ class PredSlidingWindowDataset(torch_data.Dataset, ub.NiceRepr): return batch_item -class SegmentationConfig(scfg.Config): - """ - Default configuration for setting up a training session - """ - default = { - 'nice': scfg.Path('untitled', help='A human readable tag that is "nice" for humans'), - 'workdir': scfg.Path('~/work/camvid', help='Dump all results in your workdir'), - - 'workers': scfg.Value(0, help='number of parallel dataloading jobs'), - 'xpu': scfg.Value('argv', help='See netharn.XPU for details. can be cpu/gpu/cuda0/0,1,2,3)'), - - 'augmenter': scfg.Value('simple', help='type of training dataset augmentation'), - 'class_weights': scfg.Value('log-median-idf', help='how to weight inbalanced classes'), - # 'class_weights': scfg.Value(None, help='how to weight inbalanced classes'), - - 'datasets': scfg.Value('special:camvid', help='Eventually you may be able to sepcify a coco file'), - 'train_dataset': scfg.Value(None), - 'vali_dataset': scfg.Value(None), - - 'arch': scfg.Value('unet', help='Network architecture code'), - 'optim': scfg.Value('adam', help='Weight optimizer. 
Can be SGD, ADAM, ADAMW, etc..'),
-
-        'input_dims': scfg.Value((128, 128), help='Window size to input to the network'),
-        'input_overlap': scfg.Value(0.25, help='amount of overlap when creating a sliding window dataset'),
-
-        'batch_size': scfg.Value(4, help='number of items per batch'),
-        'bstep': scfg.Value(1, help='number of batches before a gradient descent step'),
-
-        'max_epoch': scfg.Value(140, help='Maximum number of epochs'),
-        'patience': scfg.Value(140, help='Maximum "bad" validation epochs before early stopping'),
-
-        'lr': scfg.Value(1e-3, help='Base learning rate'),
-        'decay': scfg.Value(1e-5, help='Base weight decay'),
-
-        'focus': scfg.Value(2.0, help='focus for focal loss'),
-
-        'schedule': scfg.Value('step90', help=('Special coercable netharn code. Eg: onecycle50, step50, gamma')),
-
-        'init': scfg.Value('kaiming_normal', help='How to initialized weights. (can be a path to a pretrained model)'),
-        'pretrained': scfg.Path(help=('alternative way to specify a path to a pretrained model')),
-    }
-
-    def normalize(self):
-        if self['pretrained'] in ['null', 'None']:
-            self['pretrained'] = None
-
-        if self['pretrained'] is not None:
-            self['init'] = 'pretrained'
-
-
 def setup_coco_datasets():
     """
     TODO:
@@ -931,7 +954,7 @@ def setup_harn(cmdline=True, **kw):
             sampler, config['input_dims'],
             input_overlap=((tag == 'train') and config['input_overlap']),
-            augmenter=((tag == 'train') and config['augmenter']),
+            augment=((tag == 'train') and config['augment']),
         )
         for tag, sampler in samplers.items()
     }
diff --git a/netharn/criterions/focal.py b/netharn/criterions/focal.py
index bfc6c01..d038d3a 100644
--- a/netharn/criterions/focal.py
+++ b/netharn/criterions/focal.py
@@ -288,6 +288,8 @@ class FocalLoss(torch.nn.modules.loss._WeightedLoss):
                  reduction=ELEMENTWISE_MEAN, ignore_index=-100):
         size_average, reduce, reduction = _backwards_compat_reduction_kw(
             size_average, reduce, reduction)
+        if isinstance(weight, list):
+            weight = torch.FloatTensor(weight)
         super(FocalLoss, self).__init__(weight=weight, reduction=reduction)
         self.focus = focus
         self.ignore_index = ignore_index
--
GitLab
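With the guard added to ``FocalLoss.__init__`` above, per-class weights can now be passed as a plain list. A small usage sketch (assuming the usual ``(input, target)`` call signature that FocalLoss shares with ``torch.nn.CrossEntropyLoss``):

    import torch
    from netharn.criterions.focal import FocalLoss

    # the list is coerced to a FloatTensor by the new isinstance check
    criterion = FocalLoss(focus=2.0, weight=[1.0, 2.0, 0.5])
    scores = torch.randn(8, 3)            # batch of 8 items, 3 classes
    labels = torch.randint(0, 3, (8,))
    loss = criterion(scores, labels)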
From b459bcd6e5f25db3006d5eef8aeb8b6a7445b912 Mon Sep 17 00:00:00 2001
From: joncrall
Date: Tue, 19 Nov 2019 12:50:58 -0500
Subject: [PATCH 10/24] Fix batch size issue

---
 CHANGELOG.md        |  1 +
 netharn/fit_harn.py | 13 +++++++++++--
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a26085b..32c096f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

 ### Fixed
 * Small issues in CIFAR Example
 * Small `imgaug` issue in `examples/sseg_camvid.py` and `examples/segmentation.py`
+* FitHarn no longer fails when loaders are missing batch sizes

 ## Version 0.5.1

diff --git a/netharn/fit_harn.py b/netharn/fit_harn.py
index e812354..bf03353 100644
--- a/netharn/fit_harn.py
+++ b/netharn/fit_harn.py
@@ -1609,7 +1609,14 @@ class CoreMixin(object):
         """
         harn.debug('_run_epoch {}, tag={}, learn={}'.format(harn.epoch, tag, learn))
         harn.debug(' * len(loader) = {}'.format(len(loader)))
-        harn.debug(' * loader.batch_size = {}'.format(loader.batch_size))
+
+        try:
+            bsize = loader.batch_sampler.batch_size
+        except AttributeError:
+            # Some loaders might have variable batch sizes
+            bsize = None
+
+        harn.debug(' * loader.batch_sampler.batch_size = {}'.format(bsize))

         harn.current_tag = tag

@@ -1625,7 +1632,6 @@ class CoreMixin(object):
         # call prepare epoch hook
         harn.prepare_epoch()

-        bsize = loader.batch_sampler.batch_size
         msg = harn._batch_msg({'loss': -1}, bsize, learn)
         desc = tag + ' ' + msg
         if harn.main_prog is None:
@@ -2382,6 +2388,9 @@ class FitHarn(ExtraMixins, InitializeMixin, ProgMixin, LogMixin, SnapshotMixin,
             idx (int): the iteration number
             first (bool, default=False): if True, trigger on the first
                 iteration, otherwise dont.
+
+        Returns:
+            bool: if it is time to do something or not
         """
         n = harn.intervals[tag]
         if n is None:
--
GitLab

From 4d056cdee9cfd859dbec80fdcf2c54433da9e8bc Mon Sep 17 00:00:00 2001
From: Jon Crall
Date: Tue, 19 Nov 2019 13:57:15 -0500
Subject: [PATCH 11/24] main name

---
 examples/mnist.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/mnist.py b/examples/mnist.py
index f669976..67162ff 100644
--- a/examples/mnist.py
+++ b/examples/mnist.py
@@ -300,7 +300,7 @@ def setup_harn(**kw):
     return harn


-def train_mnist():
+def main():
     harn = setup_harn()

     reset = ub.argflag('--reset')
@@ -330,4 +330,4 @@ if __name__ == '__main__':

         tensorboard --logdir ~/data/work/mnist/fit/nice
     """
-    train_mnist()
+    main()
--
GitLab

From 305131a103db3200bdea7091bada6ca396d20e3d Mon Sep 17 00:00:00 2001
From: joncrall
Date: Thu, 21 Nov 2019 11:47:18 -0500
Subject: [PATCH 12/24] don't use deprecated ubelt funcs

---
 dev/debug_minst.py                    | 2 +-
 dev/manage_snapshots.py               | 4 ++--
 examples/ggr_matching.py              | 4 ++--
 examples/grab_voc.py                  | 2 +-
 examples/yolo_voc.py                  | 6 +++---
 netharn/data/transforms/augmenters.py | 2 +-
 netharn/data/voc.py                   | 4 ++--
 netharn/export/deployer.py            | 4 ++--
 netharn/fit_harn.py                   | 4 ++--
 netharn/util/mplutil.py               | 4 ++--
 netharn/util/util_fname.py            | 4 ++--
 requirements/runtime.txt              | 2 +-
 setup.py                              | 4 ++--
 super_setup.py                        | 4 ++--
 14 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/dev/debug_minst.py b/dev/debug_minst.py
index 3b5c733..abf09fc 100644
--- a/dev/debug_minst.py
+++ b/dev/debug_minst.py
@@ -1,7 +1,7 @@
 # flake8: noqa
 import sys
 import ubelt as ub
-sys.path.append(ub.truepath('~/code/netharn/examples'))
+sys.path.append(ub.expandpath('~/code/netharn/examples'))
 from mnist_matching import setup_harn

 resnet_harn = setup_harn(arch='resnet').initialize()
diff --git a/dev/manage_snapshots.py b/dev/manage_snapshots.py
index 5b1fbfd..50264be 100755
--- a/dev/manage_snapshots.py
+++ b/dev/manage_snapshots.py
@@ -74,7 +74,7 @@ def _devcheck_remove_dead_runs(workdir, dry=True, dead_num_snap_thresh=10,
     """
     import ubelt as ub

-    # workdir = ub.truepath('~/work/foobar')
+    # workdir = ub.expandpath('~/work/foobar')

     print('Checking for dead / dangling sessions in your runs dir')

@@ -212,7 +212,7 @@ def _devcheck_manage_snapshots(workdir, recent=5, factor=10, dry=True):
         places where there is a significant change from a global perspective.
# Specify your workdir - workdir = ub.truepath('~/work/voc_yolo2') + workdir = ub.expandpath('~/work/voc_yolo2') """ USE_RANGE_HUERISTIC = True diff --git a/examples/ggr_matching.py b/examples/ggr_matching.py index 2f4e3fd..87712ff 100644 --- a/examples/ggr_matching.py +++ b/examples/ggr_matching.py @@ -504,7 +504,7 @@ def setup_sampler(config): if config['dbname'] == 'ggr2': print('Creating torch CocoDataset') - root = ub.truepath('~/data/') + root = ub.expandpath('~/data/') print('root = {!r}'.format(root)) train_dset = ndsampler.CocoDataset( @@ -527,7 +527,7 @@ def setup_sampler(config): if config['dbname'] == 'ggr2-revised': print('Creating torch CocoDataset') - root = ub.truepath('~/data/ggr2.coco.revised') + root = ub.expandpath('~/data/ggr2.coco.revised') print('root = {!r}'.format(root)) train_dset = ndsampler.CocoDataset( diff --git a/examples/grab_voc.py b/examples/grab_voc.py index 29d479a..b1cb1cd 100644 --- a/examples/grab_voc.py +++ b/examples/grab_voc.py @@ -202,7 +202,7 @@ def ensure_voc_data(dpath=None, force=False, years=[2007, 2012]): >>> devkit_dpath = ensure_voc_data() """ if dpath is None: - dpath = ub.truepath('~/data/VOC') + dpath = ub.expandpath('~/data/VOC') devkit_dpath = join(dpath, 'VOCdevkit') # if force or not exists(devkit_dpath): ub.ensuredir(dpath) diff --git a/examples/yolo_voc.py b/examples/yolo_voc.py index 14aab46..b3c624f 100644 --- a/examples/yolo_voc.py +++ b/examples/yolo_voc.py @@ -88,7 +88,7 @@ class YoloVOCDataset(nh.data.voc.VOCDataset): Example: >>> # DISABLE_DOCTSET >>> import sys, ubelt - >>> sys.path.append(ubelt.truepath('~/code/netharn/examples')) + >>> sys.path.append(ubelt.expandpath('~/code/netharn/examples')) >>> from yolo_voc import * >>> self = YoloVOCDataset(split='train') >>> index = 7 @@ -109,7 +109,7 @@ class YoloVOCDataset(nh.data.voc.VOCDataset): Example: >>> # DISABLE_DOCTSET >>> import sys, ubelt - >>> sys.path.append(ubelt.truepath('~/code/netharn/examples')) + >>> sys.path.append(ubelt.expandpath('~/code/netharn/examples')) >>> from yolo_voc import * >>> self = YoloVOCDataset(split='test') >>> index = 0 @@ -769,7 +769,7 @@ def setup_yolo_harness(bsize=16, workers=0): hyper = nh.HyperParams(**{ 'nice': nice, - 'workdir': ub.truepath('~/work/voc_yolo2'), + 'workdir': ub.expandpath('~/work/voc_yolo2'), 'datasets': datasets, 'loaders': loaders, diff --git a/netharn/data/transforms/augmenters.py b/netharn/data/transforms/augmenters.py index 1de16cf..f084e00 100644 --- a/netharn/data/transforms/augmenters.py +++ b/netharn/data/transforms/augmenters.py @@ -88,7 +88,7 @@ class HSVShift(augmenter_base.ParamatarizedAugmenter): Ignore: >>> from netharn.data.transforms.augmenters import * - >>> lnpre = ub.import_module_from_path(ub.truepath('~/code/lightnet/lightnet/data/transform/_preprocess.py')) + >>> lnpre = ub.import_module_from_path(ub.expandpath('~/code/lightnet/lightnet/data/transform/_preprocess.py')) >>> self = lnpre.HSVShift(0.1, 1.5, 1.5) >>> from PIL import Image >>> img = demodata_hsv_image() diff --git a/netharn/data/voc.py b/netharn/data/voc.py index e3ef463..74cf1cf 100644 --- a/netharn/data/voc.py +++ b/netharn/data/voc.py @@ -56,7 +56,7 @@ class VOCDataset(torch_data.Dataset, ub.NiceRepr): """ def __init__(self, devkit_dpath=None, split='train', years=[2007, 2012]): if devkit_dpath is None: - # ub.truepath('~/data/VOC/VOCdevkit') + # ub.expandpath('~/data/VOC/VOCdevkit') devkit_dpath = self.ensure_voc_data(years=years) self.devkit_dpath = devkit_dpath @@ -139,7 +139,7 @@ class VOCDataset(torch_data.Dataset, 
ub.NiceRepr): >>> VOCDataset.ensure_voc_data() """ if dpath is None: - dpath = ub.truepath('~/data/VOC') + dpath = ub.expandpath('~/data/VOC') devkit_dpath = join(dpath, 'VOCdevkit') # if force or not exists(devkit_dpath): ub.ensuredir(dpath) diff --git a/netharn/export/deployer.py b/netharn/export/deployer.py index 78b5761..b8733c1 100644 --- a/netharn/export/deployer.py +++ b/netharn/export/deployer.py @@ -445,10 +445,10 @@ class DeployedModel(ub.NiceRepr): Ignore: from netharn.export.deployer import * - fcnn116 = ub.import_module_from_path(ub.truepath('~/remote/hermes/tmp/fcnn116.py')) + fcnn116 = ub.import_module_from_path(ub.expandpath('~/remote/hermes/tmp/fcnn116.py')) model = fcnn116.FCNN116() initkw = {} - snap_fpath = ub.truepath('~/remote/hermes/tmp/fcnn116.pt') + snap_fpath = ub.expandpath('~/remote/hermes/tmp/fcnn116.pt') train_info_fpath = None self = DeployedModel.custom(snap_fpath, model, initkw) zipfile = self.package(dpath) diff --git a/netharn/fit_harn.py b/netharn/fit_harn.py index bf03353..4d7c6b8 100644 --- a/netharn/fit_harn.py +++ b/netharn/fit_harn.py @@ -1285,7 +1285,7 @@ class CoreMixin(object): if harn._tlog is not None: train_base = os.path.dirname(harn.nice_dpath or harn.train_dpath) harn.info('dont forget to start:\n' - ' tensorboard --logdir ' + ub.compressuser(train_base)) + ' tensorboard --logdir ' + ub.shrinkuser(train_base)) try: if harn._check_termination(): @@ -1416,7 +1416,7 @@ class CoreMixin(object): harn.info('harn.train_dpath = {!r}'.format(harn.train_dpath)) harn.info('harn.nice_dpath = {!r}'.format(harn.nice_dpath)) harn.info('view tensorboard results for this run via:\n' - ' tensorboard --logdir ' + ub.compressuser(train_base)) + ' tensorboard --logdir ' + ub.shrinkuser(train_base)) deploy_fpath = harn._deploy() diff --git a/netharn/util/mplutil.py b/netharn/util/mplutil.py index 59c9544..f893b59 100644 --- a/netharn/util/mplutil.py +++ b/netharn/util/mplutil.py @@ -346,7 +346,7 @@ def _save_requested(fpath_, save_parts): fig = plt.gcf() fig.dpi = dpi - fpath_strict = ub.truepath(fpath) + fpath_strict = ub.expandpath(fpath) CLIP_WHITE = ub.argflag('--clipwhite') from netharn import util @@ -399,7 +399,7 @@ def _save_requested(fpath_, save_parts): savekw['dpi'] = dpi savekw['edgecolor'] = 'none' savekw['bbox_inches'] = extract_axes_extents(fig, combine=True) # replaces need for clipwhite - absfpath_ = ub.truepath(fpath) + absfpath_ = ub.expandpath(fpath) fig.savefig(absfpath_, **savekw) if CLIP_WHITE: diff --git a/netharn/util/util_fname.py b/netharn/util/util_fname.py index d336306..964e846 100644 --- a/netharn/util/util_fname.py +++ b/netharn/util/util_fname.py @@ -266,8 +266,8 @@ def align_paths(paths1, paths2): Speed: >>> import ubelt as ub - >>> paths1 = [ub.truepath('~/foo/{:04d}/{:04d}').format(i, j) for i in range(2) for j in range(10000)] - >>> paths2 = [ub.truepath('~/bar/{:04d}/{:04d}').format(i, j) for i in range(2) for j in range(10000)] + >>> paths1 = [ub.expandpath('~/foo/{:04d}/{:04d}').format(i, j) for i in range(2) for j in range(10000)] + >>> paths2 = [ub.expandpath('~/bar/{:04d}/{:04d}').format(i, j) for i in range(2) for j in range(10000)] >>> np.random.shuffle(paths2) >>> aligned = align_paths(paths1, paths2) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 265e61b..6f711e6 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -6,7 +6,7 @@ six >= 1.11.0 torch >= 1.0.0 numpy >= 1.14.5 scipy >= 1.2.1 -ubelt >= 0.8.2 +ubelt >= 0.8.4 progiter >= 0.0.2 parse >= 1.8.4 packaging >= 17.1 diff 
--git a/setup.py b/setup.py index dd291f6..09372b5 100755 --- a/setup.py +++ b/setup.py @@ -123,7 +123,7 @@ def parse_requirements(fname='requirements.txt'): def clean_repo(repodir, modname, rel_paths=[]): """ - repodir = ub.truepath('~/code/netharn/') + repodir = ub.expandpath('~/code/netharn/') modname = 'netharn' rel_paths = [ 'netharn/util/nms/cpu_nms.c', @@ -186,7 +186,7 @@ def clean_repo(repodir, modname, rel_paths=[]): def clean(): """ - __file__ = ub.truepath('~/code/netharn/setup.py') + __file__ = ub.expandpath('~/code/netharn/setup.py') """ modname = 'netharn' repodir = dirname(__file__) diff --git a/super_setup.py b/super_setup.py index dbec519..3e71e9d 100755 --- a/super_setup.py +++ b/super_setup.py @@ -418,10 +418,10 @@ class RepoRegistry(ub.NiceRepr): if cwd is None: cwd = os.get_cwd() if cwd != MY_CWD: - print('cd ' + ub.compressuser(cwd)) + print('cd ' + ub.shrinkuser(cwd)) MY_CWD = cwd print(cmd) - print('cd ' + ub.compressuser(ORIG_CWD)) + print('cd ' + ub.shrinkuser(ORIG_CWD)) def determine_code_dpath(): -- GitLab From 6693f2b13c2453b3d1c2d0e530cbab32ed161602 Mon Sep 17 00:00:00 2001 From: joncrall Date: Thu, 21 Nov 2019 18:07:35 -0500 Subject: [PATCH 13/24] remove dependency on packaging --- netharn/criterions/focal.py | 4 ++-- requirements/runtime.txt | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/netharn/criterions/focal.py b/netharn/criterions/focal.py index d038d3a..2ebf344 100644 --- a/netharn/criterions/focal.py +++ b/netharn/criterions/focal.py @@ -3,10 +3,10 @@ import torch # NOQA import torch.nn.functional as F import torch.nn.modules import kwarray -from packaging import version +from distutils.version import LooseVersion -if version.parse(torch.__version__) < version.parse('1.0.0'): +if LooseVersion(torch.__version__) < LooseVersion('1.0.0'): ELEMENTWISE_MEAN = 'elementwise_mean' else: ELEMENTWISE_MEAN = 'mean' diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 6f711e6..551e871 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -4,12 +4,11 @@ pandas >= 0.23.3 torchvision >= 0.2.1 six >= 1.11.0 torch >= 1.0.0 -numpy >= 1.14.5 +numpy >= 1.9.0 scipy >= 1.2.1 ubelt >= 0.8.4 progiter >= 0.0.2 parse >= 1.8.4 -packaging >= 17.1 pyflakes >= 1.6.0 astunparse >= 1.6.1 pygtrie >= 2.2 -- GitLab From 85db88a044f0143b98b073e05224a5bba3c416b1 Mon Sep 17 00:00:00 2001 From: joncrall Date: Thu, 21 Nov 2019 19:05:22 -0500 Subject: [PATCH 14/24] wip --- netharn/device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netharn/device.py b/netharn/device.py index c5304d6..f22b501 100644 --- a/netharn/device.py +++ b/netharn/device.py @@ -5,7 +5,6 @@ An XPU is an abstracted (X) procesesing unit (PU) with a common API for running torch operations on a CPU, GPU, or many GPUs. 
""" from __future__ import absolute_import, division, print_function -import psutil import ubelt as ub import warnings import torch @@ -410,6 +409,7 @@ class XPU(ub.NiceRepr): 'used': 0, } if self._device_ids is None: + import psutil tup = psutil.virtual_memory() MB = 1 / 2 ** 20 info['total'] += tup.total * MB -- GitLab From 8fb889954e90d06b7bcbc6a02718ac58d4a44372 Mon Sep 17 00:00:00 2001 From: joncrall Date: Thu, 21 Nov 2019 23:56:47 -0500 Subject: [PATCH 15/24] Cleanup dependencies --- CHANGELOG.md | 2 + netharn/data/voc.py | 7 +-- netharn/layers/gauss.py | 6 +-- netharn/metrics/_devcheck_detmetrics.py | 3 +- netharn/metrics/clf_report.py | 6 ++- netharn/metrics/confusion_vectors.py | 2 +- netharn/util/util_averages.py | 16 ++++-- netharn/util/util_json.py | 8 ++- netharn/util/util_subextreme.py | 2 +- requirements/optional.txt | 2 + requirements/runtime.txt | 8 --- requirements/super_setup.txt | 5 +- setup.py | 70 +++++++++++++------------ super_setup.py | 2 +- 14 files changed, 76 insertions(+), 63 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32c096f..82d3c30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Changed * Smoothing no longer applied to lr (learning rate) and momentum monitor plots +* pandas and scipy are now optional (in this package) +* removed several old dependencies ### Fixed * Small issues in CIFAR Example diff --git a/netharn/data/voc.py b/netharn/data/voc.py index 74cf1cf..1400802 100644 --- a/netharn/data/voc.py +++ b/netharn/data/voc.py @@ -6,9 +6,6 @@ resizes images to a standard size. from os.path import exists from os.path import join import re -import scipy -import scipy.sparse -import cv2 import torch import glob import ubelt as ub @@ -212,6 +209,7 @@ class VOCDataset(torch_data.Dataset, ub.NiceRepr): def _load_item(self, index, inp_size=None): # from netharn.models.yolo2.utils.yolo import _offset_boxes + import cv2 image = self._load_image(index) annot = self._load_annotation(index) @@ -232,12 +230,15 @@ class VOCDataset(torch_data.Dataset, ub.NiceRepr): return hwc, boxes, gt_classes def _load_image(self, index): + import cv2 fpath = self.gpaths[index] imbgr = cv2.imread(fpath, flags=cv2.IMREAD_COLOR) imrgb_255 = cv2.cvtColor(imbgr, cv2.COLOR_BGR2RGB) return imrgb_255 def _load_annotation(self, index): + import scipy + import scipy.sparse import xml.etree.ElementTree as ET fpath = self.apaths[index] tree = ET.parse(fpath) diff --git a/netharn/layers/gauss.py b/netharn/layers/gauss.py index 922153b..96eb5c0 100644 --- a/netharn/layers/gauss.py +++ b/netharn/layers/gauss.py @@ -2,10 +2,8 @@ import math import numpy as np import torch -import torch.nn.functional as F -import scipy -import scipy.ndimage from netharn.layers import common +import torch.nn.functional as F class Conv1d_pad(torch.nn.Conv1d, common.ModuleMixin): @@ -137,6 +135,8 @@ class GaussianBlurNd(common.Module): if self.separable: # Calculate the 1d Gaussian kernel # Follow scipy.ndimage method closely + import scipy + import scipy.ndimage kernel1d = scipy.ndimage.filters._gaussian_kernel1d(sigma, order=0, radius=lw)[::-1] kernel1d = torch.from_numpy(np.ascontiguousarray(kernel1d)) diff --git a/netharn/metrics/_devcheck_detmetrics.py b/netharn/metrics/_devcheck_detmetrics.py index 12b29f0..3d8e91f 100644 --- a/netharn/metrics/_devcheck_detmetrics.py +++ b/netharn/metrics/_devcheck_detmetrics.py @@ -1,5 +1,4 @@ import numpy as np -import pandas as pd import ubelt as ub from 
netharn.metrics.detections import _ave_precision from netharn.metrics.detections import detection_confusions @@ -193,6 +192,7 @@ def _devcheck_voc_consistency2(): Check how cocoeval works https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py """ + import pandas as pd from netharn.metrics.detections import DetectionMetrics xdata = [] ydatas = ub.ddict(list) @@ -270,6 +270,7 @@ def _devcheck_voc_consistency(): Check how cocoeval works https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py """ + import pandas as pd import netharn as nh # method = 'voc2012' method = 'voc2007' diff --git a/netharn/metrics/clf_report.py b/netharn/metrics/clf_report.py index 62c8683..686db8a 100644 --- a/netharn/metrics/clf_report.py +++ b/netharn/metrics/clf_report.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function, unicode_literals import warnings -import scipy as sp import numpy as np -import pandas as pd import ubelt as ub @@ -65,6 +63,7 @@ def classification_report(y_true, y_pred, target_names=None, >>> rs.append(report) >>> import plottool as pt >>> pt.qtensure() + >>> import pandas as pd >>> df = pd.DataFrame(rs).drop(['raw'], axis=1) >>> delta = df.subtract(df['target'], axis=0) >>> sqrd_error = np.sqrt((delta ** 2).sum(axis=0)) @@ -73,6 +72,8 @@ def classification_report(y_true, y_pred, target_names=None, >>> ys = df.to_dict(orient='list') >>> pt.multi_plot(ydata_list=ys) """ + import pandas as pd + import scipy as sp import sklearn.metrics from sklearn.preprocessing import LabelEncoder @@ -351,6 +352,7 @@ def ovr_classification_report(mc_y_true, mc_probs, target_names=None, 2 0.8000 0.8693 0.2623 0.2652 0.1602 5 0.2778 """ + import pandas as pd import sklearn.metrics if metrics is None: diff --git a/netharn/metrics/confusion_vectors.py b/netharn/metrics/confusion_vectors.py index 652aaab..e92b530 100644 --- a/netharn/metrics/confusion_vectors.py +++ b/netharn/metrics/confusion_vectors.py @@ -1,5 +1,4 @@ import numpy as np -import pandas as pd import ubelt as ub import warnings from .functional import fast_confusion_matrix @@ -151,6 +150,7 @@ class ConfusionVectors(object): sample_weight=data.get('weight', None) ) + import pandas as pd cm = pd.DataFrame(matrix, index=list(self.classes), columns=list(self.classes)) if compress: diff --git a/netharn/util/util_averages.py b/netharn/util/util_averages.py index 9ad234c..a3747c0 100644 --- a/netharn/util/util_averages.py +++ b/netharn/util/util_averages.py @@ -1,11 +1,19 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function, unicode_literals import collections -import pandas as pd import ubelt as ub import numpy as np +def _isnull(v): + try: + import pandas as pd + except Exception: + return v is None or np.isnan(v) + else: + return pd.isnull(v) + + class MovingAve(ub.NiceRepr): """ Abstract moving averages API @@ -108,7 +116,7 @@ class CumMovingAve(MovingAve): def update(self, other): for k, v in other.items(): - if pd.isnull(v): + if _isnull(v): if self.nan_method == 'ignore': continue elif self.nan_method == 'zero': @@ -167,7 +175,7 @@ class WindowedMovingAve(MovingAve): def update(self, other): for k, v in other.items(): - if pd.isnull(v): + if _isnull(v): v = 0 if k not in self.totals: self.history[k] = collections.deque() @@ -285,7 +293,7 @@ class ExpMovingAve(MovingAve): """ alpha = self.alpha for k, v in other.items(): - if pd.isnull(v): + if _isnull(v): v = 0 if self.correct_bias: if k not in 
self.means: diff --git a/netharn/util/util_json.py b/netharn/util/util_json.py index 60e01f6..3584676 100644 --- a/netharn/util/util_json.py +++ b/netharn/util/util_json.py @@ -4,7 +4,6 @@ import json import six import numpy as np import ubelt as ub -import pandas as pd def walk_json(node): @@ -101,7 +100,12 @@ def write_json(fpath, data): """ Write human readable json files """ - if isinstance(data, pd.DataFrame): + try: + import pandas as pd + except ImportError: + pd = None + + if pd and isinstance(data, pd.DataFrame): # pretty pandas json_text = (json.dumps(json.loads(data.to_json()), indent=4)) elif isinstance(data, dict): diff --git a/netharn/util/util_subextreme.py b/netharn/util/util_subextreme.py index 9323a57..d8399a0 100644 --- a/netharn/util/util_subextreme.py +++ b/netharn/util/util_subextreme.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function, unicode_literals import numpy as np -import scipy.signal import ubelt as ub # NOQA @@ -106,6 +105,7 @@ def _hist_argmaxima(hist, centers=None, maxima_thresh=None): """ # FIXME: Not handling general cases # [0] index because argrelmaxima returns a tuple + import scipy.signal argmaxima_ = scipy.signal.argrelextrema(hist, np.greater)[0] if len(argmaxima_) == 0: argmaxima_ = hist.argmax() diff --git a/requirements/optional.txt b/requirements/optional.txt index 6e15327..1bdc086 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,3 +1,4 @@ +pandas >= 0.23.3 tqdm >= 4.23.4 Pillow >= 5.2.0 opencv-python >= 3.4.1 @@ -6,6 +7,7 @@ seaborn>=0.9.0 h5py >= 2.8.0 protobuf >= 3.6.0 scikit-learn >= 0.19.1 +scipy >= 1.2.1 tensorboard_logger >= 0.1.0 tensorboard >= 1.8.0 diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 551e871..49aeaa5 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,13 +1,8 @@ -# Misc -# ==== -pandas >= 0.23.3 torchvision >= 0.2.1 six >= 1.11.0 torch >= 1.0.0 numpy >= 1.9.0 -scipy >= 1.2.1 ubelt >= 0.8.4 -progiter >= 0.0.2 parse >= 1.8.4 pyflakes >= 1.6.0 astunparse >= 1.6.1 @@ -20,7 +15,4 @@ kwarray >= 0.4.0 kwimage >= 0.4.0 kwplot >= 0.4.0 - -# Python 2.7 Only -# ==================== qualname>=0.1.0;python_version < '3.0' diff --git a/requirements/super_setup.txt b/requirements/super_setup.txt index c2c0318..fbbb3b8 100644 --- a/requirements/super_setup.txt +++ b/requirements/super_setup.txt @@ -1,6 +1,3 @@ click gitpython -ubelt >= 0.8.2 -Cython >= 0.28.4 -cffi >= 1.11.5 -scikit-build >= 0.8.1 +ubelt >= 0.8.4 diff --git a/setup.py b/setup.py index 09372b5..6b495d5 100755 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ def parse_description(): return '' -def parse_requirements(fname='requirements.txt'): +def parse_requirements(fname='requirements.txt', with_version=False): """ Parse the package dependencies listed in a requirements file but strips specific versioning information. 
@@ -61,6 +61,7 @@ def parse_requirements(fname='requirements.txt'): CommandLine: python -c "import setup; print(setup.parse_requirements())" + python -c "import setup; print(chr(10).join(setup.parse_requirements(with_version=True)))" """ from os.path import exists import re @@ -75,28 +76,27 @@ def parse_requirements(fname='requirements.txt'): target = line.split(' ')[1] for info in parse_require_file(target): yield info - elif line.startswith('-e '): - info = {} - info['package'] = line.split('#egg=')[1] - yield info else: - # Remove versioning from the package - pat = '(' + '|'.join(['>=', '==', '>']) + ')' - parts = re.split(pat, line, maxsplit=1) - parts = [p.strip() for p in parts] - - info = {} - info['package'] = parts[0] - if len(parts) > 1: - op, rest = parts[1:] - if ';' in rest: - # Handle platform specific dependencies - # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies - version, platform_deps = map(str.strip, rest.split(';')) - info['platform_deps'] = platform_deps - else: - version = rest # NOQA - info['version'] = (op, version) + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) yield info def parse_require_file(fpath): @@ -107,17 +107,21 @@ def parse_requirements(fname='requirements.txt'): for info in parse_line(line): yield info - # This breaks on pip install, so check that it exists. 
- packages = [] - if exists(require_fpath): - for info in parse_require_file(require_fpath): - package = info['package'] - if not sys.version.startswith('3.4'): - # apparently package_deps are broken in 3.4 - platform_deps = info.get('platform_deps') - if platform_deps is not None: - package += ';' + platform_deps - packages.append(package) + def gen_packages_items(): + if exists(require_fpath): + for info in parse_require_file(require_fpath): + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + yield item + + packages = list(gen_packages_items()) return packages diff --git a/super_setup.py b/super_setup.py index 3e71e9d..829d61f 100755 --- a/super_setup.py +++ b/super_setup.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- """ Requirements: - pip install gitpython + pip install gitpython click ubelt """ from os.path import exists from os.path import join -- GitLab From 1095fdf3ebf4b3c803cdec4e8f4add71c2129b30 Mon Sep 17 00:00:00 2001 From: joncrall Date: Fri, 22 Nov 2019 00:13:20 -0500 Subject: [PATCH 16/24] wip --- setup.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 6b495d5..2e060ea 100755 --- a/setup.py +++ b/setup.py @@ -56,8 +56,12 @@ def parse_requirements(fname='requirements.txt', with_version=False): Parse the package dependencies listed in a requirements file but strips specific versioning information. - TODO: - perhaps use https://github.com/davidfischer/requirements-parser instead + Args: + fname (str): path to requirements file + with_version (bool, default=False): if true include version specs + + Returns: + List[str]: list of requirements items CommandLine: python -c "import setup; print(setup.parse_requirements())" -- GitLab From d72ed25f393c1803a35df825af03d027d4a47131 Mon Sep 17 00:00:00 2001 From: joncrall Date: Fri, 22 Nov 2019 00:14:09 -0500 Subject: [PATCH 17/24] wip --- super_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/super_setup.py b/super_setup.py index 829d61f..cdb39e1 100755 --- a/super_setup.py +++ b/super_setup.py @@ -489,7 +489,7 @@ def make_netharn_registry(): remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, ), CommonRepo( - name='kwplot', branch='dev/0.4.0', remote='public', + name='kwplot', branch='dev/0.4.1', remote='public', remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, ), -- GitLab From eface5017115247b43f9959c366a58219f0b5042 Mon Sep 17 00:00:00 2001 From: joncrall Date: Fri, 22 Nov 2019 00:27:26 -0500 Subject: [PATCH 18/24] wip --- setup.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 2e060ea..fa42817 100755 --- a/setup.py +++ b/setup.py @@ -71,14 +71,15 @@ def parse_requirements(fname='requirements.txt', with_version=False): import re require_fpath = fname - def parse_line(line): + def parse_line(line, base='.'): """ Parse information from a line in a requirements text file """ if line.startswith('-r '): # Allow specifying requirements in other files - target = line.split(' ')[1] - for info in parse_require_file(target): + new_fname = line.split(' ')[1] + new_fpath = join(base, new_fname) + for info in parse_require_file(new_fpath): yield info else: info = {'line': line} @@ -104,11 +105,12 @@ def 
parse_requirements(fname='requirements.txt', with_version=False): yield info def parse_require_file(fpath): + base = dirname(fpath) with open(fpath, 'r') as f: for line in f.readlines(): line = line.strip() if line and not line.startswith('#'): - for info in parse_line(line): + for info in parse_line(line, base): yield info def gen_packages_items(): -- GitLab From 78de05e0c4323acb96089f7b8b398efc9dac37ea Mon Sep 17 00:00:00 2001 From: joncrall Date: Fri, 22 Nov 2019 12:50:38 -0500 Subject: [PATCH 19/24] wip --- super_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/super_setup.py b/super_setup.py index cdb39e1..30af0aa 100755 --- a/super_setup.py +++ b/super_setup.py @@ -500,7 +500,7 @@ def make_netharn_registry(): remotes={'public': 'git@gitlab.kitware.com:utils/scriptconfig.git'}, ), CommonRepo( - name='ndsampler', branch='dev/0.5.0', remote='public', + name='ndsampler', branch='dev/0.5.1', remote='public', remotes={'public': 'git@gitlab.kitware.com:computer-vision/ndsampler.git'}, ), -- GitLab From fc7c4946fec3023c77c47da5750c64c89b317618 Mon Sep 17 00:00:00 2001 From: joncrall Date: Sat, 23 Nov 2019 15:05:39 -0500 Subject: [PATCH 20/24] Fixed issue with ambiguous remotes --- super_setup.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/super_setup.py b/super_setup.py index 30af0aa..0fa2e7b 100755 --- a/super_setup.py +++ b/super_setup.py @@ -12,6 +12,12 @@ import ubelt as ub import functools +class ShellException(Exception): + """ + Raised when shell returns a non-zero error code + """ + + class DirtyRepoError(Exception): """ If the repo is in an unexpected state, its very easy to break things using @@ -179,7 +185,7 @@ class Repo(ub.NiceRepr): repo.debug(info['err']) if info['ret'] != 0: - raise Exception(ub.repr2(info)) + raise ShellException(ub.repr2(info)) return info @property @@ -228,7 +234,7 @@ class Repo(ub.NiceRepr): fmtkw['sha1'] = repo._cmd('git rev-parse HEAD', verbose=0)['out'].strip() try: fmtkw['tag'] = repo._cmd('git describe --tags', verbose=0)['out'].strip() + ',' - except Exception: + except ShellException: fmtkw['tag'] = ',' fmtkw['branch'] = repo.pygit.active_branch.name + ',' fmtkw['repo'] = repo.name + ',' @@ -277,7 +283,7 @@ class Repo(ub.NiceRepr): remote_name, remote_url, repo)) if not dry: repo._cmd('git remote add {} {}'.format(remote_name, remote_url)) - except Exception: + except ShellException: if remote_name == repo.remote: # Only error if the main remote is not available raise @@ -294,18 +300,25 @@ class Repo(ub.NiceRepr): repo.debug('WARNING: remote={} does not exist'.format(remote)) else: if remote.exists(): + repo.debug('Requested remote does exists') remote_branchnames = [ref.remote_head for ref in remote.refs] if repo.branch not in remote_branchnames: repo.info('Branch name not found in local remote. Attempting to fetch') repo._cmd('git fetch {}'.format(remote.name)) - # remote.fetch() + repo.info('Fetch was successful') + else: + repo.debug('Requested remote does NOT exist') + + try: + repo._cmd('git checkout {}'.format(repo.branch)) + except ShellException: + repo.debug('Checkout failed. Branch name might be ambiguous. 
Trying again')
+            repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch))
 
-        repo._cmd('git checkout {}'.format(repo.branch))
         # try:
         #     repo._cmd('git checkout {}'.format(repo.branch))
         # except Exception:
         #     repo._cmd('git fetch --all')
-        #     repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch))
 
         tracking_branch = repo.pygit.active_branch.tracking_branch()
         if tracking_branch is None or tracking_branch.remote_name != repo.remote:
-- 
GitLab


From 3634bfe7f1153c470a2b6212a30a963afc32cd2f Mon Sep 17 00:00:00 2001
From: joncrall
Date: Sat, 23 Nov 2019 17:13:06 -0500
Subject: [PATCH 21/24] Add support for https super setup

---
 super_setup.py | 257 +++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 206 insertions(+), 51 deletions(-)

diff --git a/super_setup.py b/super_setup.py
index 0fa2e7b..0d163eb 100755
--- a/super_setup.py
+++ b/super_setup.py
@@ -4,6 +4,7 @@
 Requirements:
     pip install gitpython click ubelt
 """
+import re
 from os.path import exists
 from os.path import join
 from os.path import dirname
@@ -49,6 +50,129 @@ def parse_version(package):
     return visitor.version
 
 
+class GitURL(object):
+    """
+    Represent and transform git URLs between protocols defined in [3]_.
+
+    The code in GitURL is largely derived from [1]_ and [2]_.
+    Credit to @coala and @FriendCode.
+
+    Note:
+        while this code aims to support protocols defined in [3]_, it is only
+        tested for specific use cases and therefore might need to be improved.
+
+    References:
+        .. [1] https://github.com/coala/git-url-parse
+        .. [2] https://github.com/FriendCode/giturlparse.py
+        .. [3] https://git-scm.com/docs/git-clone#URLS
+
+    Example:
+        >>> self = GitURL('git@gitlab.kitware.com:computer-vision/netharn.git')
+        >>> print(ub.repr2(self.parts()))
+        >>> print(self.format('ssh'))
+        >>> print(self.format('https'))
+        >>> self = GitURL('https://gitlab.kitware.com/computer-vision/netharn.git')
+        >>> print(ub.repr2(self.parts()))
+        >>> print(self.format('ssh'))
+        >>> print(self.format('https'))
+    """
+    SYNTAX_PATTERNS = {
+        # git allows for a url style syntax
+        'url': re.compile(r'(?P<transport>\w+://)'
+                          r'((?P<user>\w+[^@]*@))?'
+                          r'(?P<host>[a-z0-9_.-]+)'
+                          r'((?P<port>:[0-9]+))?'
+                          r'/(?P<path>.*\.git)'),
+        # git allows for ssh style syntax
+        'ssh': re.compile(r'(?P<user>\w+[^@]*@)'
+                          r'(?P<host>[a-z0-9_.-]+)'
+                          r':(?P<path>.*\.git)'),
+    }
+
+    r"""
+    Ignore:
+        # Helper to build the parse pattern regexes
+        def named(key, regex):
+            return '(?P<{}>{})'.format(key, regex)
+
+        def optional(pat):
+            return '({})?'.format(pat)
+
+        parse_patterns = {}
+        # Standard url format
+        transport = named('transport', r'\w+://')
+        user = named('user', r'\w+[^@]*@')
+        host = named('host', r'[a-z0-9_.-]+')
+        port = named('port', r':[0-9]+')
+        path = named('path', r'.*\.git')
+
+        pat = ''.join([transport, optional(user), host, optional(port), '/', path])
+        parse_patterns['url'] = pat
+
+        pat = ''.join([user, host, ':', path])
+        parse_patterns['ssh'] = pat
+        print(ub.repr2(parse_patterns))
+    """
+
+    def __init__(self, url):
+        self._url = url
+        self._parts = None
+
+    def parts(self):
+        """
+        Parses a GIT URL and returns an info dict.
+ + Returns: + dict: info about the url + + Raises: + Exception : if parsing fails + """ + info = { + 'syntax': '', + 'host': '', + 'user': '', + 'port': '', + 'path': None, + 'transport': '', + } + + for syntax, regex in self.SYNTAX_PATTERNS.items(): + match = regex.search(self._url) + if match: + info['syntax'] = syntax + info.update(match.groupdict()) + break + else: + raise Exception('Invalid URL {!r}'.format(self._url)) + + # change none to empty string + for k, v in info.items(): + if v is None: + info[k] = '' + return info + + def format(self, protocol): + """ + Change the protocol of the git URL + """ + parts = self.parts() + if protocol == 'ssh': + parts['user'] = 'git@' + url = ''.join([ + parts['user'], parts['host'], ':', parts['path'] + ]) + else: + parts['transport'] = protocol + '://' + parts['port'] = '' + parts['user'] = '' + url = ''.join([ + parts['transport'], parts['user'], parts['host'], + parts['port'], '/', parts['path'] + ]) + return url + + class Repo(ub.NiceRepr): """ Abstraction that references a git repository, and is able to manipulate it. @@ -150,6 +274,16 @@ class Repo(ub.NiceRepr): repo._pygit = None + def set_protocol(self, protocol): + """ + Changes the url protocol to either ssh or https + + Args: + protocol (str): can be ssh or https + """ + gurl = GitURL(self.url) + self.url = gurl.format(protocol) + def info(repo, msg): repo._logged_lines.append(('INFO', 'INFO: ' + msg)) if repo.verbose >= 1: @@ -288,69 +422,78 @@ class Repo(ub.NiceRepr): # Only error if the main remote is not available raise - # Ensure we are on the right branch - if repo.branch != repo.pygit.active_branch.name: - repo.debug('NEED TO SET BRANCH TO {} for {}'.format(repo.branch, repo)) - if not dry: - try: - remote = repo.pygit.remotes[repo.remote] - if not remote.exists(): - raise IndexError - except IndexError: - repo.debug('WARNING: remote={} does not exist'.format(remote)) - else: - if remote.exists(): - repo.debug('Requested remote does exists') - remote_branchnames = [ref.remote_head for ref in remote.refs] - if repo.branch not in remote_branchnames: - repo.info('Branch name not found in local remote. Attempting to fetch') - repo._cmd('git fetch {}'.format(remote.name)) - repo.info('Fetch was successful') + # Ensure we have the right remote + try: + remote = repo.pygit.remotes[repo.remote] + if not remote.exists(): + raise IndexError + else: + repo.debug('The requested remote={} name exists'.format(remote)) + except IndexError: + repo.debug('WARNING: remote={} does not exist'.format(remote)) + else: + if remote.exists(): + repo.debug('Requested remote does exists') + remote_branchnames = [ref.remote_head for ref in remote.refs] + if repo.branch not in remote_branchnames: + repo.info('Branch name not found in local remote. Attempting to fetch') + if dry: + repo.info('dry run, not fetching') else: - repo.debug('Requested remote does NOT exist') + repo._cmd('git fetch {}'.format(remote.name)) + repo.info('Fetch was successful') + else: + repo.debug('Requested remote does NOT exist') - try: - repo._cmd('git checkout {}'.format(repo.branch)) - except ShellException: - repo.debug('Checkout failed. Branch name might be ambiguous. 
Trying again') - repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch)) + # Ensure the remote points to the right place + if repo.url not in list(remote.urls): + repo.debug('WARNING: The requested url={} disagrees with remote urls={}'.format(repo.url, list(remote.urls))) - # try: - # repo._cmd('git checkout {}'.format(repo.branch)) - # except Exception: - # repo._cmd('git fetch --all') + if dry: + repo.info('Dry run, not updating remote url') + else: + repo.info('Updating remote url') + repo._cmd('git remote set-url {} {}'.format(repo.remote, repo.url)) + + # Ensure we are on the right branch + if repo.branch != repo.pygit.active_branch.name: + repo.debug('NEED TO SET BRANCH TO {} for {}'.format(repo.branch, repo)) + try: + repo._cmd('git checkout {}'.format(repo.branch)) + except ShellException: + repo.debug('Checkout failed. Branch name might be ambiguous. Trying again') + repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch)) tracking_branch = repo.pygit.active_branch.tracking_branch() if tracking_branch is None or tracking_branch.remote_name != repo.remote: repo.debug('NEED TO SET UPSTREAM FOR FOR {}'.format(repo)) - if not dry: - try: - remote = repo.pygit.remotes[repo.remote] - if not remote.exists(): - raise IndexError - except IndexError: - repo.debug('WARNING: remote={} does not exist'.format(remote)) - else: - if remote.exists(): - remote_branchnames = [ref.remote_head for ref in remote.refs] - if repo.branch not in remote_branchnames: + + try: + remote = repo.pygit.remotes[repo.remote] + if not remote.exists(): + raise IndexError + except IndexError: + repo.debug('WARNING: remote={} does not exist'.format(remote)) + else: + if remote.exists(): + remote_branchnames = [ref.remote_head for ref in remote.refs] + if repo.branch not in remote_branchnames: + if dry: + repo.info('Branch name not found in local remote. Dry run, use ensure to attempt to fetch') + else: repo.info('Branch name not found in local remote. 
Attempting to fetch') - remote.fetch() + repo._cmd('git fetch {}'.format(repo.remote)) + + remote_branchnames = [ref.remote_head for ref in remote.refs] + if repo.branch not in remote_branchnames: + raise Exception('Branch name still does not exist') + if not dry: repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format( remote=repo.remote, branch=repo.branch )) - - # try: - # repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format( - # remote=repo.remote, branch=repo.branch - # )) - # except Exception: - # # remote.fetch() - # repo._cmd('git fetch --all') - # repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format( - # remote=repo.remote, branch=repo.branch - # )) + else: + repo.info('Would attempt to set upstream') # Print some status repo.debug(' * branch = {} -> {}'.format( @@ -540,6 +683,18 @@ def main(): if ub.argflag('--serial'): num_workers = 0 + protocol = ub.argval('--protocol', None) + if ub.argflag('--https'): + protocol = 'https' + if ub.argflag('--http'): + protocol = 'http' + if ub.argflag('--ssh'): + protocol = 'ssh' + + if protocol is not None: + for repo in registery.repos: + repo.set_protocol(protocol) + default_context_settings = { 'help_option_names': ['-h', '--help'], 'allow_extra_args': True, -- GitLab From 47572051692d43ea27e15278c7bcebc0edc5192f Mon Sep 17 00:00:00 2001 From: joncrall Date: Sat, 23 Nov 2019 17:21:11 -0500 Subject: [PATCH 22/24] wip --- super_setup.py | 121 +++++++++++++++++++++++++++---------------------- 1 file changed, 68 insertions(+), 53 deletions(-) diff --git a/super_setup.py b/super_setup.py index 0d163eb..13c88c1 100755 --- a/super_setup.py +++ b/super_setup.py @@ -425,75 +425,86 @@ class Repo(ub.NiceRepr): # Ensure we have the right remote try: remote = repo.pygit.remotes[repo.remote] - if not remote.exists(): - raise IndexError - else: - repo.debug('The requested remote={} name exists'.format(remote)) except IndexError: - repo.debug('WARNING: remote={} does not exist'.format(remote)) - else: - if remote.exists(): - repo.debug('Requested remote does exists') - remote_branchnames = [ref.remote_head for ref in remote.refs] - if repo.branch not in remote_branchnames: - repo.info('Branch name not found in local remote. Attempting to fetch') - if dry: - repo.info('dry run, not fetching') - else: - repo._cmd('git fetch {}'.format(remote.name)) - repo.info('Fetch was successful') - else: - repo.debug('Requested remote does NOT exist') - - # Ensure the remote points to the right place - if repo.url not in list(remote.urls): - repo.debug('WARNING: The requested url={} disagrees with remote urls={}'.format(repo.url, list(remote.urls))) - - if dry: - repo.info('Dry run, not updating remote url') + if not dry: + raise AssertionError('Something went wrong') else: - repo.info('Updating remote url') - repo._cmd('git remote set-url {} {}'.format(repo.remote, repo.url)) + remote = None - # Ensure we are on the right branch - if repo.branch != repo.pygit.active_branch.name: - repo.debug('NEED TO SET BRANCH TO {} for {}'.format(repo.branch, repo)) + if remote is not None: try: - repo._cmd('git checkout {}'.format(repo.branch)) - except ShellException: - repo.debug('Checkout failed. Branch name might be ambiguous. 
Trying again') - repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch)) - - tracking_branch = repo.pygit.active_branch.tracking_branch() - if tracking_branch is None or tracking_branch.remote_name != repo.remote: - repo.debug('NEED TO SET UPSTREAM FOR FOR {}'.format(repo)) - - try: - remote = repo.pygit.remotes[repo.remote] if not remote.exists(): raise IndexError + else: + repo.debug('The requested remote={} name exists'.format(remote)) except IndexError: repo.debug('WARNING: remote={} does not exist'.format(remote)) else: if remote.exists(): + repo.debug('Requested remote does exists') remote_branchnames = [ref.remote_head for ref in remote.refs] if repo.branch not in remote_branchnames: + repo.info('Branch name not found in local remote. Attempting to fetch') if dry: - repo.info('Branch name not found in local remote. Dry run, use ensure to attempt to fetch') + repo.info('dry run, not fetching') else: - repo.info('Branch name not found in local remote. Attempting to fetch') - repo._cmd('git fetch {}'.format(repo.remote)) + repo._cmd('git fetch {}'.format(remote.name)) + repo.info('Fetch was successful') + else: + repo.debug('Requested remote does NOT exist') - remote_branchnames = [ref.remote_head for ref in remote.refs] - if repo.branch not in remote_branchnames: - raise Exception('Branch name still does not exist') + # Ensure the remote points to the right place + if repo.url not in list(remote.urls): + repo.debug('WARNING: The requested url={} disagrees with remote urls={}'.format(repo.url, list(remote.urls))) - if not dry: - repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format( - remote=repo.remote, branch=repo.branch - )) - else: - repo.info('Would attempt to set upstream') + if dry: + repo.info('Dry run, not updating remote url') + else: + repo.info('Updating remote url') + repo._cmd('git remote set-url {} {}'.format(repo.remote, repo.url)) + + # Ensure we are on the right branch + if repo.branch != repo.pygit.active_branch.name: + repo.debug('NEED TO SET BRANCH TO {} for {}'.format(repo.branch, repo)) + try: + repo._cmd('git checkout {}'.format(repo.branch)) + except ShellException: + repo.debug('Checkout failed. Branch name might be ambiguous. Trying again') + try: + repo._cmd('git checkout -b {} {}/{}'.format(repo.branch, repo.remote, repo.branch)) + except ShellException: + raise Exception('does the branch exist on the remote?') + + tracking_branch = repo.pygit.active_branch.tracking_branch() + if tracking_branch is None or tracking_branch.remote_name != repo.remote: + repo.debug('NEED TO SET UPSTREAM FOR FOR {}'.format(repo)) + + try: + remote = repo.pygit.remotes[repo.remote] + if not remote.exists(): + raise IndexError + except IndexError: + repo.debug('WARNING: remote={} does not exist'.format(remote)) + else: + if remote.exists(): + remote_branchnames = [ref.remote_head for ref in remote.refs] + if repo.branch not in remote_branchnames: + if dry: + repo.info('Branch name not found in local remote. Dry run, use ensure to attempt to fetch') + else: + repo.info('Branch name not found in local remote. 
Attempting to fetch') + repo._cmd('git fetch {}'.format(repo.remote)) + + remote_branchnames = [ref.remote_head for ref in remote.refs] + if repo.branch not in remote_branchnames: + raise Exception('Branch name still does not exist') + + if not dry: + repo._cmd('git branch --set-upstream-to={remote}/{branch} {branch}'.format( + remote=repo.remote, branch=repo.branch + )) + else: + repo.info('Would attempt to set upstream') # Print some status repo.debug(' * branch = {} -> {}'.format( @@ -644,6 +655,10 @@ def make_netharn_registry(): name='kwimage', branch='dev/0.5.2', remote='public', remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwimage.git'}, ), + CommonRepo( + name='kwannot', branch='master', remote='public', + remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwannot.git'}, + ), CommonRepo( name='kwplot', branch='dev/0.4.1', remote='public', remotes={'public': 'git@gitlab.kitware.com:computer-vision/kwplot.git'}, -- GitLab From 3b2c80a692b36105714171780b97e3b531c091af Mon Sep 17 00:00:00 2001 From: joncrall Date: Mon, 25 Nov 2019 13:13:17 -0500 Subject: [PATCH 23/24] wip --- CHANGELOG.md | 1 + netharn/util/util_zip.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82d3c30..23454d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm * Small issues in CIFAR Example * Small `imgaug` issue in `examples/sseg_camvid.py` and `examples/segmentation.py` * FitHarn no longer fails when loaders are missing batch sizes +* Fixed windows issue in `util_zip.zopen`. ## Version 0.5.1 diff --git a/netharn/util/util_zip.py b/netharn/util/util_zip.py index 1413d08..fd885eb 100644 --- a/netharn/util/util_zip.py +++ b/netharn/util/util_zip.py @@ -124,7 +124,7 @@ class zopen(ub.NiceRepr): _handle = None if exists(self.fpath): _handle = open(self.fpath, self.mode) - elif '.zip/' in self.fpath: + elif '.zip/' in self.fpath or '.zip' + os.path.sep in self.fpath: fpath = self.fpath archivefile, internal = split_archive(fpath) myzip = zipfile.ZipFile(archivefile, 'r') -- GitLab From 395bdabbf8bb6f8bb410cc99e7b2b6e8efc74326 Mon Sep 17 00:00:00 2001 From: joncrall Date: Mon, 25 Nov 2019 13:50:56 -0500 Subject: [PATCH 24/24] Fix strip ansi --- CHANGELOG.md | 1 + netharn/fit_harn.py | 2 +- netharn/util/__init__.py | 4 ++-- netharn/util/util_misc.py | 22 ++++++++++++++++++++++ 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 23454d7..28f59c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm * Small `imgaug` issue in `examples/sseg_camvid.py` and `examples/segmentation.py` * FitHarn no longer fails when loaders are missing batch sizes * Fixed windows issue in `util_zip.zopen`. +* Fixed runtime dependency on `strip_ansi` from xdoctest. 
## Version 0.5.1 diff --git a/netharn/fit_harn.py b/netharn/fit_harn.py index 4d7c6b8..415b351 100644 --- a/netharn/fit_harn.py +++ b/netharn/fit_harn.py @@ -145,9 +145,9 @@ from netharn.exceptions import (StopTraining, CannotResume, TrainingDiverged, from netharn import util from netharn.util import profiler +from netharn.util import strip_ansi from netharn import export -from xdoctest.utils import strip_ansi try: import tensorboard_logger diff --git a/netharn/util/__init__.py b/netharn/util/__init__.py index 84333d0..bda3e9e 100644 --- a/netharn/util/__init__.py +++ b/netharn/util/__init__.py @@ -44,7 +44,7 @@ from .util_io import (read_arr, read_h5arr, write_arr, write_h5arr,) from .util_iter import (roundrobin,) from .util_json import (LossyJSONEncoder, NumpyEncoder, read_json, walk_json, write_json,) -from .util_misc import (SupressPrint, FlatIndexer) +from .util_misc import (SupressPrint, FlatIndexer, strip_ansi) from .util_resources import (ensure_ulimit,) from .util_slider import (SlidingIndexDataset, SlidingSlices, SlidingWindow, Stitcher,) @@ -138,7 +138,7 @@ __all__ = ['ArrayAPI', 'BatchNormContext', 'Boxes', 'CacheStamp', 'Color', 'subpixel_translate', 'torch_ravel_multi_index', 'trainable_layers', 'uniform', 'uniform32', 'util_dataframe', 'walk_json', 'warp_points', 'warp_tensor', 'wide_strides_1d', 'write_arr', - 'write_h5arr', 'write_json', 'zopen', 'FlatIndexer'] + 'write_h5arr', 'write_json', 'zopen', 'FlatIndexer', 'strip_ansi'] # diff --git a/netharn/util/util_misc.py b/netharn/util/util_misc.py index f8cdf57..35dfa11 100644 --- a/netharn/util/util_misc.py +++ b/netharn/util/util_misc.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import numpy as np import ubelt as ub +import re class SupressPrint(): @@ -84,3 +85,24 @@ class FlatIndexer(ub.NiceRepr): """ base = self.cums[outer] - self.lens[outer] return base + inner + + +def strip_ansi(text): + r""" + Removes all ansi directives from the string. + + References: + http://stackoverflow.com/questions/14693701/remove-ansi + https://stackoverflow.com/questions/13506033/filtering-out-ansi-escape-sequences + + Examples: + >>> line = '\t\u001b[0;35mBlabla\u001b[0m \u001b[0;36m172.18.0.2\u001b[0m' + >>> escaped_line = strip_ansi(line) + >>> assert escaped_line == '\tBlabla 172.18.0.2' + """ + # ansi_escape1 = re.compile(r'\x1b[^m]*m') + # text = ansi_escape1.sub('', text) + # ansi_escape2 = re.compile(r'\x1b\[([0-9,A-Z]{1,2}(;[0-9]{1,2})?(;[0-9]{3})?)?[m|K]?') + ansi_escape3 = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]', flags=re.IGNORECASE) + text = ansi_escape3.sub('', text) + return text -- GitLab