From 436e12c1dea1cfb59d1ffc73cf1abc56e5cc81ae Mon Sep 17 00:00:00 2001 From: Paul Beasly <paul.beasly@kitware.com> Date: Wed, 12 Jun 2024 17:36:23 -0400 Subject: [PATCH 1/3] Update fusion/predict.py with hidden_layers --- docs/source/manual/tutorial/__init__.py | 0 .../tutorial/encoder_layers_fusion_predict.sh | 133 ++++++ .../tutorial/fusion_model_layer_info.sh | 266 +++++++++++ .../manual/tutorial/tutorial9_hidden_layer.py | 446 ++++++++++++++++++ geowatch/tasks/fusion/predict.py | 37 ++ 5 files changed, 882 insertions(+) create mode 100644 docs/source/manual/tutorial/__init__.py create mode 100644 docs/source/manual/tutorial/encoder_layers_fusion_predict.sh create mode 100644 docs/source/manual/tutorial/fusion_model_layer_info.sh create mode 100644 docs/source/manual/tutorial/tutorial9_hidden_layer.py diff --git a/docs/source/manual/tutorial/__init__.py b/docs/source/manual/tutorial/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/source/manual/tutorial/encoder_layers_fusion_predict.sh b/docs/source/manual/tutorial/encoder_layers_fusion_predict.sh new file mode 100644 index 000000000..c302bcb91 --- /dev/null +++ b/docs/source/manual/tutorial/encoder_layers_fusion_predict.sh @@ -0,0 +1,133 @@ +# output of encoder layers + +Encoder Layer 0: Linear(in_features=560, out_features=128, bias=True) +Encoder Layer 1: Sequential( + (0): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (1): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (2): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (3): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): 
Linear(in_features=128, out_features=128, bias=True) + ) + ) + (4): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (5): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (6): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (7): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) +) diff --git a/docs/source/manual/tutorial/fusion_model_layer_info.sh b/docs/source/manual/tutorial/fusion_model_layer_info.sh new file mode 100644 index 000000000..a4d96b5da --- /dev/null +++ b/docs/source/manual/tutorial/fusion_model_layer_info.sh @@ -0,0 +1,266 @@ +# output of layers for fusion predict.py mdoel + +Layer 0: RobustModuleDict( + (*): RobustModuleDict( + (B1|B10|B11|B8|B8a): InputNorm() + ) +) +Layer 1: RobustModuleDict( + (*): RobustModuleDict( + (B1|B10|B11|B8|B8a): RearrangeTokenizer( + (foot): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv2d(5, 6, kernel_size=(1, 1), stride=(1, 1)) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv2d(6, 6, kernel_size=(1, 1), stride=(1, 1)) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv2d(6, 7, kernel_size=(1, 1), stride=(1, 1)) + (noli): ReLU(inplace=True) + ) + (output): Conv2d(7, 8, kernel_size=(1, 1), stride=(1, 1)) + ) + (skip): Conv2d(5, 8, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + ) +) +Layer 2: MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=1, out_features=3, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): 
Conv0d(in_features=3, out_features=4, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv0d(in_features=4, out_features=6, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=6, out_features=8, bias=True) + ) + (skip): Conv0d(in_features=1, out_features=8, bias=True) +) +Layer 3: MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=16, out_features=14, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=14, out_features=12, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv0d(in_features=12, out_features=10, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=10, out_features=8, bias=True) + ) + (skip): Conv0d(in_features=16, out_features=8, bias=True) +) +Layer 4: MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=16, out_features=14, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=14, out_features=12, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv0d(in_features=12, out_features=10, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=10, out_features=8, bias=True) + ) + (skip): Conv0d(in_features=16, out_features=8, bias=True) +) +Layer 5: FusionEncoder( + (first): Linear(in_features=560, out_features=128, bias=True) + (layers): Sequential( + (0): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (1): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (2): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (3): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): 
ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (4): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (5): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (6): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (7): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + ) +) +Layer 6: ModuleDict( + (change): CrossEntropyLoss() + (saliency): FocalLoss() + (class): FocalLoss() +) +Layer 7: ModuleDict( + (change): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=128, out_features=86, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=86, out_features=44, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=44, out_features=2, bias=True) + ) + ) + (saliency): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=128, out_features=86, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=86, out_features=44, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=44, out_features=2, bias=True) + ) + ) + (class): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=128, out_features=87, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): 
Conv0d(in_features=87, out_features=45, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=45, out_features=4, bias=True) + ) + ) +) +Layer 8: SinePositionalEncoding() +Layer 9: SinePositionalEncoding() diff --git a/docs/source/manual/tutorial/tutorial9_hidden_layer.py b/docs/source/manual/tutorial/tutorial9_hidden_layer.py new file mode 100644 index 000000000..96b8b90e2 --- /dev/null +++ b/docs/source/manual/tutorial/tutorial9_hidden_layer.py @@ -0,0 +1,446 @@ +# Tutorial to extract hidden layer features from a pre-trained model + +import sys +import os +import ubelt as ub + + +# import module +from geowatch.tasks.fusion.predict import * # NOQA +print(Predictor) +from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings +disable_lightning_hardware_warnings() +args = None +cmdline = False +devices = None +test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir() +results_path = (test_dpath / 'predict').ensuredir() +results_path.delete() +results_path.ensuredir() + +# generate kwcoco toydataset +import kwcoco +train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral') +test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral') + +root_dpath = ub.Path(test_dpath, 'train').ensuredir() + +# create config +fit_config = kwargs = { + 'subcommand': 'fit', + 'fit.data.train_dataset': train_dset.fpath, + 'fit.data.time_steps': 2, + 'fit.data.time_span': "2m", + 'fit.data.chip_dims': 64, + 'fit.data.time_sampling': 'hardish3', + 'fit.data.num_workers': 0, + # 'package_fpath': package_fpath, + 'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer', + 'fit.model.init_args.global_change_weight': 1.0, + 'fit.model.init_args.global_class_weight': 1.0, + 'fit.model.init_args.global_saliency_weight': 1.0, + 'fit.optimizer.class_path': 'torch.optim.SGD', + 'fit.optimizer.init_args.lr': 1e-5, + 'fit.trainer.max_steps': 10, + 'fit.trainer.accelerator': 'cpu', + 'fit.trainer.devices': 1, + 'fit.trainer.max_epochs': 3, + 'fit.trainer.log_every_n_steps': 1, + 'fit.trainer.default_root_dir': os.fspath(root_dpath), +} + +from geowatch.tasks.fusion import fit_lightning +package_fpath = root_dpath / 'final_package.pt' + +# sets up the model +fit_lightning.main(fit_config) + +# Unfortunately, its not as easy to get the package path of +# this call.. +assert ub.Path(package_fpath).exists() +# Predict via that model +predict_kwargs = kwargs = { + 'package_fpath': package_fpath, + 'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json', + 'test_dataset': test_dset.fpath, + 'datamodule': 'KWCocoVideoDataModule', + 'batch_size': 1, + 'num_workers': 0, + 'devices': devices, + 'draw_batches': 1, +} + +# Does this run the prediction model? 
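+# (Yes - predict() runs the full prediction pipeline end-to-end and writes the
+# pred.kwcoco.json requested above. The remainder of this tutorial re-traces
+# the same steps interactively so the hidden-layer activations can be
+# inspected along the way.)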
+result_dataset = predict(**kwargs) + +#----------------------------------------------------------------- +# Load config, model and dataset into predictor model +import rich +from rich.markup import escape +config = PredictConfig.cli(cmdline=cmdline, data=kwargs, strict=True) +rich.print('config = {}'.format(escape(ub.urepr(config, nl=2)))) +predictor = Predictor(config) +predictor._load_model() +predictor._load_dataset() + +#--------------------------------------------------------------- +# Move to the predictor _run() function (line 1370) +# Need to assign 'self' due to class method +self = predictor +datamodule = self.datamodule +model = self.model +config = self.config +test_coco_dataset = datamodule.coco_datasets['test'] + +# test_torch_dataset = datamodule.torch_datasets['test'] +# T, H, W = test_torch_dataset.window_dims + +# Create the results dataset as a copy of the test CocoDataset +print('Populate result dataset') +result_dataset: kwcoco.CocoDataset = test_coco_dataset.copy() + +# Remove all annotations in the results copy +if config['clear_annots']: + result_dataset.clear_annotations() + +# Change all paths to be absolute paths +result_dataset.reroot(absolute=True) +if not config['pred_dataset']: + raise ValueError( + f'Must specify path to the output (predicted) kwcoco file. ' + f'Got {config["pred_dataset"]=}') +result_dataset.fpath = str(ub.Path(config['pred_dataset']).expand()) + +from geowatch.utils.lightning_ext import util_device +print('devices = {!r}'.format(config['devices'])) +devices = util_device.coerce_devices(config['devices']) +print('devices = {!r}'.format(devices)) +if len(devices) > 1: + raise NotImplementedError('TODO: handle multiple devices') +device = devices[0] + +fit_config = self.fit_config + +#--------------------------------------------------------------- +# The critical work happens here at _predict_critical_loop(), line 1407 +# The data module produces batches +# We move into the function to perform the operations (line 557) + +print('Predict on device = {!r}'.format(device)) +downweight_edges = config.downweight_edges + +UNPACKAGE_METHOD_HACK = 0 +if UNPACKAGE_METHOD_HACK: + # unpackage model hack + from geowatch.tasks.fusion import methods + unpackged_method = methods.MultimodalTransformer(**model.hparams) + unpackged_method.load_state_dict(model.state_dict()) + model = unpackged_method + +model = model.to(device) + +# Introspection of config +# Resolve what tasks are requested by looking at what heads are available. +global_head_weights = getattr(model, 'global_head_weights', {}) +if config['with_change'] == 'auto': + config['with_change'] = getattr(model, 'global_change_weight', 1.0) or global_head_weights.get('change', 1) +if config['with_class'] == 'auto': + config['with_class'] = getattr(model, 'global_class_weight', 1.0) or global_head_weights.get('class', 1) +if config['with_saliency'] == 'auto': + config['with_saliency'] = getattr(model, 'global_saliency_weight', 0.0) or global_head_weights.get('saliency', 1) + + +test_dataloader = datamodule.test_dataloader() +batch_iter = iter(test_dataloader) + +from kwutil import util_progress +pman = util_progress.ProgressManager(backend='rich') + +# prog = ub.ProgIter(batch_iter, desc='fusion predict', verbose=1, freq=1) + +# Make threads after starting background proces. 
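+# (The BlockingJobQueue below is handed to the stitching managers; finished
+# images are submitted to it so stitched rasters are written to disk in worker
+# threads instead of blocking the predict loop.)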
+if config.write_workers == 'datamodule': + config.write_workers = datamodule.num_workers +writer_queue = util_parallel.BlockingJobQueue( + mode='thread', + # mode='serial', + max_workers=config.write_workers +) +result_fpath = ub.Path(result_dataset.fpath) +result_fpath.parent.ensuredir() +print('result_fpath = {!r}'.format(result_fpath)) + +# Definition of the building stitching managers starts line 198 +# Need a change here to include hidden layer +# 226-247 pass info +# Stitching manager needs bands and short code +# chan_code = is important +# create a stitching manager and add to the dicstionary +# Model knows what classes it wants to predict +# creates a 4-channel raster, then takes "saliency" +# create another stitching manager to gather descriptor features +# "coco-stitcher" + +#-------------------------------------------------------------- +stitch_managers = build_stitching_managers( + config, model, result_dataset, + writer_queue=writer_queue +) +stitch_managers + +#-------------------------------------------------------------- + +expected_outputs = set(stitch_managers.keys()) +got_outputs = None +writable_outputs = None + +print('Expected outputs: ' + str(expected_outputs)) + +head_key_mapping = { + 'saliency_probs': 'saliency', + 'class_probs': 'class', + 'change_probs': 'change', + 'hidden_layers': 'hidden_layers', +} + +from geowatch.tasks.fusion.predict import _jsonify + +info = result_dataset.dataset.get('info', []) + +pred_dpath = ub.Path(result_dataset.fpath).parent +rich.print(f'Pred Dpath: [link={pred_dpath}]{pred_dpath}[/link]') + +DRAW_BATCHES = config.draw_batches +if DRAW_BATCHES: + viz_batch_dpath = (pred_dpath / '_viz_pred_batches').ensuredir() + +config_resolved = _jsonify(config.asdict()) +fit_config = _jsonify(fit_config) + +from kwcoco.util import util_json +unresolvable = list(util_json.find_json_unserializable(config_resolved)) +if unresolvable: + import warnings + warnings.warn(f'NotReproducibleWarning: Found unresolvable configuration options: {unresolvable!r}') + config_walker = ub.IndexableWalker(config_resolved) + for unresolvable_item in unresolvable: + _value = unresolvable_item['data'] + config_walker[unresolvable_item['loc']] = f'Unresolvable: {_value}' + + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + assert not unresolvable, 'should have entered dummy values for unresolvable data' + +if config['record_context']: + from geowatch.utils import process_context + proc_context = process_context.ProcessContext( + name='geowatch.tasks.fusion.predict', + type='process', + config=config_resolved, + track_emissions=config['track_emissions'], + # Extra information was adjusted in 0.15.1 to ensure more relevant + # fit params are returned here. A script + # ~/code/geowatch/geowatch/cli/experimental/fixup_predict_kwcoco_metadata.py + # exist to help update old results to use this new format. 
+ extra={ + 'fit_config': fit_config + } + ) + # assert not list(util_json.find_json_unserializable(proc_context.obj)) + info.append(proc_context.obj) + proc_context.start() + test_coco_dataset = datamodule.coco_datasets['test'] + proc_context.add_disk_info(test_coco_dataset.fpath) + +memory_monitor_timer = ub.Timer().tic() +memory_monitor_interval_seconds = 60 * 60 +with_memory_units = bool(ub.modname_to_modpath('pint')) + +#-------------------------------------------------------------------------- +#--------------------------------------------------------------------------- +torch.set_grad_enabled(False) + +EMERGENCY_INPUT_AGREEMENT_HACK = 1 and hasattr(model, 'input_norms') + +# prog.set_extra(' <will populate stats after first video>') +# pman.start() + +prog = pman.progiter(batch_iter, desc='fusion predict') +_batch_iter = iter(prog) + +if 0: + item = test_dataloader.dataset[0] + + orig_batch = next(_batch_iter) + item = orig_batch[0] + item['target'] + frame = item['frames'][0] + ub.peek(frame['modes'].values()).shape + +batch_idx = 0 + +# can ignore pman +pman.stopall() + +item = test_dataloader.dataset[0] +orig_batch = next(_batch_iter) + +# check out orig_batch + +# Iterates through every item in dataset +# We just step through loop once +batch_idx += 1 +batch_trs = [] + +# Move data onto the prediction device, grab spacetime region info +fixed_batch = [] +for item in orig_batch: + if item is None: + continue + item = item.copy() + batch_gids = [frame['gid'] for frame in item['frames']] + frame_infos = [ub.udict(f) & { + 'gid', + 'output_space_slice', + 'output_image_dsize', + 'output_weights', + 'scale_outspace_from_vid', + } for f in item['frames']] + batch_trs.append({ + 'space_slice': tuple(item['target']['space_slice']), + # 'scale': item['target']['scale'], + 'scale': item['target'].get('scale', None), + 'gids': batch_gids, + 'frame_infos': frame_infos, + 'fliprot_params': item['target'].get('fliprot_params', None) + }) + position_tensors = item.get('positional_tensors', None) + if position_tensors is not None: + for k, v in position_tensors.items(): + position_tensors[k] = v.to(device) + + filtered_frames = [] + for frame in item['frames']: + frame = frame.copy() + sensor = frame['sensor'] + if EMERGENCY_INPUT_AGREEMENT_HACK: + try: + known_sensor_modes = model.input_norms[sensor] + except KeyError: + known_sensor_modes = None + continue + filtered_modes = {} + modes = frame['modes'] + for key, mode in modes.items(): + if EMERGENCY_INPUT_AGREEMENT_HACK: + if key not in known_sensor_modes: + continue + filtered_modes[key] = mode.to(device) + frame['modes'] = filtered_modes + filtered_frames.append(frame) + item['frames'] = filtered_frames + fixed_batch.append(item) + +# fixes batch components +batch = fixed_batch + +# can view again `batch` + + +from geowatch.utils.util_netharn import _debug_inbatch_shapes +print(_debug_inbatch_shapes(batch)) + + +if memory_monitor_timer.toc() > memory_monitor_interval_seconds: + # TODO: monitor memory usage and report if it looks like we + # are about to run out of memory, and maybe do something to + # handle it. + from geowatch.utils import util_hardware + mem_info = util_hardware.get_mem_info(with_units=with_memory_units) + print(f'\n\nmem_info = {ub.urepr(mem_info, nl=1)}\n\n') + memory_monitor_timer.tic() + + +# Entire purpose of the function. Where model connects to data and runs +# neural network. 
Arbritrary code we don't control +# Need to inject something to find hidden state +# Prepare input --> model.forward_step --> prepare output +outputs = model.forward_step(batch, with_loss=False) + +# view `outputs` +# view `outputs.keys()` + +# Compatibility step +outputs = {head_key_mapping.get(k, k): v for k, v in outputs.items()} +outputs.keys() + +# Checks/hack, runs once +got_outputs = list(outputs.keys()) +prog.ensure_newline() +writable_outputs = set(got_outputs) & expected_outputs +print('got_outputs = {!r}'.format(got_outputs)) +print('writable_outputs = {!r}'.format(writable_outputs)) + +# For each item in the batch, process the results +for head_key in writable_outputs: + head_probs = outputs[head_key] + head_stitcher = stitch_managers[head_key] + chan_keep_idxs = head_stitcher.head_keep_idxs + +# hack +predicted_frame_slice = slice(None) + +num_batches = len(batch_trs) + +# ----------------------------------------------------------------------- +for bx in range(num_batches): + target: dict = batch_trs[bx] + item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] + + # check the shape, was ([2,64,64,2]) for initial run + item_head_probs.shape + + # Keep only the channels we want to write to disk + # convert to numpy + item_head_relevant_probs = [p[..., chan_keep_idxs] for p in item_head_probs] + bin_probs = [p.detach().cpu().numpy() for p in item_head_relevant_probs] + + + + # check probs + DEBUG_PRED_SPATIAL_COVERAGE=0 + + frame_infos: list[dict] = target['frame_infos'][predicted_frame_slice] + + fliprot_params: dict = target['fliprot_params'] + # Update the stitcher with this windowed prediction + for probs, frame_info in zip(bin_probs, frame_infos): + if fliprot_params is not None: + # Undo fliprot TTA + probs = data_utils.inv_fliprot(probs, **fliprot_params) + + gid = frame_info['gid'] + output_image_dsize = frame_info['output_image_dsize'] + output_space_slice = frame_info['output_space_slice'] + scale_outspace_from_vid = frame_info['scale_outspace_from_vid'] + + if DEBUG_PRED_SPATIAL_COVERAGE: + image_id_to_video_space_slices[gid].append(target['space_slice']) + image_id_to_output_space_slices[gid].append(output_space_slice) + + output_weights = frame_info.get('output_weights', None) + + # View some details + print(probs.shape) + print(frame_info) + print(f"model._activation['hidden'] {model._activation_cache['hidden']}") + + + + # Checks if an image is done, submits to finalize + head_stitcher.accumulate_image( + gid, output_space_slice, probs, + asset_dsize=output_image_dsize, + scale_asset_from_stitchspace=scale_outspace_from_vid, + weights=output_weights, + downweight_edges=downweight_edges, + ) diff --git a/geowatch/tasks/fusion/predict.py b/geowatch/tasks/fusion/predict.py index 52c8dfe1c..6d0f545b1 100644 --- a/geowatch/tasks/fusion/predict.py +++ b/geowatch/tasks/fusion/predict.py @@ -201,6 +201,43 @@ class PredictConfig(DataModuleConfigMixin): is used. ''')) +# --------------Add hidden layer hook to model---------------- + +def _register_hidden_layer_hook(model): + # TODO: generalize to other models + # Specific to UNetR model + # These are at half of the output image resolution. 
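+    # The forward hook registered below stashes the chosen layer's output in
+    # model._activation_cache['hidden'] on every forward pass, so after
+    # model.forward_step(batch, with_loss=False) the prediction loop can read
+    # the hidden features back out (e.g. model._activation_cache['hidden'], as
+    # done in tutorial9_hidden_layer.py) without modifying the model itself.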
+ + model._activation_cache = {} + + # print("info on model", dir(model)) + + print("Enumerate over model.children()\n") + for i, layer in enumerate(model.children()): + print(f"Layer {i}: {layer}") + + # Not sure this is the correct code + encoder_layers = list(model.encoder.children()) + + print(f"\nNumber of encoder layers {len(encoder_layers)}\n") + + + print("Enumerate over model.encoder.children()\n") + for i, layer in enumerate(encoder_layers): + print(f"Encoder Layer {i}: {layer}") + + def record_hidden_activation(layer, input, output): + activation = output.detach() + model._activation_cache['hidden'] = activation + print(f"Hidden layer activation shape {activation.shape}") + + # Desired layer is the 5th layer of the 2nd encoder layer + layer_of_interest = list(model.encoder.children())[1][6] + layer_of_interest._forward_hooks.clear() + layer_of_interest.register_forward_hook(record_hidden_activation) +# ---------------------------------------------------------- + + # --------------Add hidden layer hook to model---------------- -- GitLab From 55e5c7c6dd730226cab961247f6984e71a295bcd Mon Sep 17 00:00:00 2001 From: Paul Beasly <paul.beasly@kitware.com> Date: Thu, 20 Jun 2024 12:10:27 -0400 Subject: [PATCH 2/3] Recent update of tutorial9 --- .../manual/tutorial/fusion_predict_notes.sh | 14 + .../tutorial/hidden_descriptors_notes.sh | 511 ++++++++++++++++++ .../manual/tutorial/hidden_layer_notes.sh | 43 ++ .../manual/tutorial/module_path_test.py | 6 + docs/source/manual/tutorial/record_demo.sh | 430 +++++++++++++++ .../manual/tutorial/tutorial9_hidden_layer.py | 49 +- geowatch/tasks/fusion/predict.py | 6 +- 7 files changed, 1046 insertions(+), 13 deletions(-) create mode 100644 docs/source/manual/tutorial/fusion_predict_notes.sh create mode 100644 docs/source/manual/tutorial/hidden_descriptors_notes.sh create mode 100644 docs/source/manual/tutorial/hidden_layer_notes.sh create mode 100644 docs/source/manual/tutorial/module_path_test.py create mode 100644 docs/source/manual/tutorial/record_demo.sh diff --git a/docs/source/manual/tutorial/fusion_predict_notes.sh b/docs/source/manual/tutorial/fusion_predict_notes.sh new file mode 100644 index 000000000..ac063c8f7 --- /dev/null +++ b/docs/source/manual/tutorial/fusion_predict_notes.sh @@ -0,0 +1,14 @@ +# notes on the structure and flow of the fusion predict model + +# Fusion predict creates predicted heatmaps +# Need to modify code to extract hidden layers + +# 1. Imports packages + +# 2. Define config classes + +# 2A. DataConfig `DataModuleConfigMixin()` class +# - does this need the `with_hidden` kwarg? +# 2B. PredictConfig, takes the previous config as arguments +# Noted as "Prediction script for the fusion task" +# - added `with_hidden_layers = scfg.Value('auto', help=None)` diff --git a/docs/source/manual/tutorial/hidden_descriptors_notes.sh b/docs/source/manual/tutorial/hidden_descriptors_notes.sh new file mode 100644 index 000000000..b7b14b4d3 --- /dev/null +++ b/docs/source/manual/tutorial/hidden_descriptors_notes.sh @@ -0,0 +1,511 @@ +__doc__=" + +Notes on exposing hidden layer descipritors in fusion/predict.py +================================================================ + +The script takes two inputs; the model and kwcoco data. Need to add a +stitching manager that gathers descritors/features from designated layer. Port +the paradigm from landcover/predict.py to fusion/predict.py. This will add +an option to output intermediate features from the model. 
+" + +# Overall steps for running the code: +# Setup uses the doctest info +# kwargs get overwritten by predict_kwargs +# Now the info is loaded into the namespace + +# the chan_code is where the features are saved + +#-----------------Basically----------------- +# 1. Get Activations from the model +# 2. Send to the stitching manager +#------------------------------------------------- + +# Next is to perform the prediction _run() function +# need to assign self to predictor for the class method +# script to modify +geowatch/tasks/fusion/predict.py + +# take the hidden layer hook from the script +geowatch/tasks/landcover/predict.py + +# Start with an interactive python sesstion +# Run code in doctests to setup paths and config. (approx line 1017-1069) +from geowatch.tasks.fusion.predict import * # NOQA (line 1017) +# ...to... +root_dpath = ub.Path(test_dpath, 'train').ensuredir() # line 1030 + +# test files go into cache directory specified by variable +test_dpath + +# run the code for config +fit_config = kwargs = {...} + +# the model being used is fit_lightning. To generate the model +fit_lightning.main(fit_config) + +# setup the predict_kwargs that overwrite the generic kwargs +# and perform the prediction (approx lines 1060-1069) + + +# Now we have the required info in the namespace. + +# Now move to the main script to run the predictor (line 1370) +# need to assign self to predictor for the class method + + +#-------------------------------------------- +# Notes on porting paradigm. The file is +tasks/landcover/predict.py +# line 227-240 finds the hidden layer +# lines 238-240 perform the register function +#----------------------------------------------------- + +# History of commands in interactive python session +# Need to run the ipython session in the geowatch directory + +# ----------------------------------------------------------------- +# List of what is imported in first line of code below +# {'utils', 'kwarray', 'datamodules', 'main', 'DataModuleConfigMixin', +# 'PredictConfig', 'kwimage', 'ub', 'Predictor', 'resolve_datamodule', +# 'profile', 'quantize_float01', 'before_import', 'build_stitching_managers', +# 'np', 'kwcoco', 'util_parallel', 'monkey_torch', 'monkey_torchmetrics', +# 'predict', 'CocoStitchingManager', 'scfg', 'torch', 'data_utils'} + + +#--------------------------------------------------------------------- +# From doctest (line 1016), import modules and set up some config args +from geowatch.tasks.fusion.predict import * # NOQA +import os +from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings +disable_lightning_hardware_warnings() +args = None +cmdline = False +devices = None +test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir() +results_path = (test_dpath / 'predict').ensuredir() +results_path.delete() +results_path.ensuredir() + +# generate kwcoco toydataset +import kwcoco +train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral') +test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral') + +root_dpath = ub.Path(test_dpath, 'train').ensuredir() + +# create config +fit_config = kwargs = { + 'subcommand': 'fit', + 'fit.data.train_dataset': train_dset.fpath, + 'fit.data.time_steps': 2, + 'fit.data.time_span': "2m", + 'fit.data.chip_dims': 64, + 'fit.data.time_sampling': 'hardish3', + 'fit.data.num_workers': 0, + # 'package_fpath': package_fpath, + 'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer', + 
'fit.model.init_args.global_change_weight': 1.0, + 'fit.model.init_args.global_class_weight': 1.0, + 'fit.model.init_args.global_saliency_weight': 1.0, + 'fit.optimizer.class_path': 'torch.optim.SGD', + 'fit.optimizer.init_args.lr': 1e-5, + 'fit.trainer.max_steps': 10, + 'fit.trainer.accelerator': 'cpu', + 'fit.trainer.devices': 1, + 'fit.trainer.max_epochs': 3, + 'fit.trainer.log_every_n_steps': 1, + 'fit.trainer.default_root_dir': os.fspath(root_dpath), +} + +from geowatch.tasks.fusion import fit_lightning +package_fpath = root_dpath / 'final_package.pt' + +# sets up the model +fit_lightning.main(fit_config) + +# Unfortunately, its not as easy to get the package path of +# this call.. +assert ub.Path(package_fpath).exists() +# Predict via that model +predict_kwargs = kwargs = { + 'package_fpath': package_fpath, + 'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json', + 'test_dataset': test_dset.fpath, + 'datamodule': 'KWCocoVideoDataModule', + 'batch_size': 1, + 'num_workers': 0, + 'devices': devices, + 'draw_batches': 1, +} + +# Does this run the prediction model? +result_dataset = predict(**kwargs) + +#----------------------------------------------------------------- +# Load config, model and dataset into predictor model +import rich +from rich.markup import escape +config = PredictConfig.cli(cmdline=cmdline, data=kwargs, strict=True) +rich.print('config = {}'.format(escape(ub.urepr(config, nl=2)))) +predictor = Predictor(config) +predictor._load_model() +predictor._load_dataset() + +#--------------------------------------------------------------- +# Move to the predictor _run() function (line 1370) +# Need to assign 'self' due to class method +self = predictor +datamodule = self.datamodule +model = self.model +config = self.config +test_coco_dataset = datamodule.coco_datasets['test'] + +# test_torch_dataset = datamodule.torch_datasets['test'] +# T, H, W = test_torch_dataset.window_dims + +# Create the results dataset as a copy of the test CocoDataset +print('Populate result dataset') +result_dataset: kwcoco.CocoDataset = test_coco_dataset.copy() + +# Remove all annotations in the results copy +if config['clear_annots']: + result_dataset.clear_annotations() + +# Change all paths to be absolute paths +result_dataset.reroot(absolute=True) +if not config['pred_dataset']: + raise ValueError( + f'Must specify path to the output (predicted) kwcoco file. 
' + f'Got {config["pred_dataset"]=}') +result_dataset.fpath = str(ub.Path(config['pred_dataset']).expand()) + +from geowatch.utils.lightning_ext import util_device +print('devices = {!r}'.format(config['devices'])) +devices = util_device.coerce_devices(config['devices']) +print('devices = {!r}'.format(devices)) +if len(devices) > 1: + raise NotImplementedError('TODO: handle multiple devices') +device = devices[0] + +fit_config = self.fit_config + +#--------------------------------------------------------------- +# The critical work happens here at _predict_critical_loop(), line 1407 +# The data module produces batches +# We move into the function to perform the operations (line 557) + +print('Predict on device = {!r}'.format(device)) +downweight_edges = config.downweight_edges + +UNPACKAGE_METHOD_HACK = 0 +if UNPACKAGE_METHOD_HACK: + # unpackage model hack + from geowatch.tasks.fusion import methods + unpackged_method = methods.MultimodalTransformer(**model.hparams) + unpackged_method.load_state_dict(model.state_dict()) + model = unpackged_method + +model = model.to(device) + +# Introspection of config +# Resolve what tasks are requested by looking at what heads are available. +global_head_weights = getattr(model, 'global_head_weights', {}) +if config['with_change'] == 'auto': + config['with_change'] = getattr(model, 'global_change_weight', 1.0) or global_head_weights.get('change', 1) +if config['with_class'] == 'auto': + config['with_class'] = getattr(model, 'global_class_weight', 1.0) or global_head_weights.get('class', 1) +if config['with_saliency'] == 'auto': + config['with_saliency'] = getattr(model, 'global_saliency_weight', 0.0) or global_head_weights.get('saliency', 1) + + +test_dataloader = datamodule.test_dataloader() +batch_iter = iter(test_dataloader) + +from kwutil import util_progress +pman = util_progress.ProgressManager(backend='rich') + +# prog = ub.ProgIter(batch_iter, desc='fusion predict', verbose=1, freq=1) + +# Make threads after starting background proces. 
+if config.write_workers == 'datamodule': + config.write_workers = datamodule.num_workers +writer_queue = util_parallel.BlockingJobQueue( + mode='thread', + # mode='serial', + max_workers=config.write_workers +) +result_fpath = ub.Path(result_dataset.fpath) +result_fpath.parent.ensuredir() +print('result_fpath = {!r}'.format(result_fpath)) + +# Definition of the building stitching managers starts line 198 +# Need a change here to include hidden layer +# 226-247 pass info +# Stitching manager needs bands and short code +# chan_code = is important +# create a stitching manager and add to the dicstionary +# Model knows what classes it wants to predict +# creates a 4-channel raster, then takes "saliency" +# create another stitching manager to gather descriptor features +# "coco-stitcher" + +#-------------------------------------------------------------- +stitch_managers = build_stitching_managers( + config, model, result_dataset, + writer_queue=writer_queue +) +stitch_managers + +#-------------------------------------------------------------- + +expected_outputs = set(stitch_managers.keys()) +got_outputs = None +writable_outputs = None + +print('Expected outputs: ' + str(expected_outputs)) + +head_key_mapping = { + 'saliency_probs': 'saliency', + 'class_probs': 'class', + 'change_probs': 'change', +} + +from geowatch.tasks.fusion.predict import _jsonify + +info = result_dataset.dataset.get('info', []) + +pred_dpath = ub.Path(result_dataset.fpath).parent +rich.print(f'Pred Dpath: [link={pred_dpath}]{pred_dpath}[/link]') + +DRAW_BATCHES = config.draw_batches +if DRAW_BATCHES: + viz_batch_dpath = (pred_dpath / '_viz_pred_batches').ensuredir() + +config_resolved = _jsonify(config.asdict()) +fit_config = _jsonify(fit_config) + +from kwcoco.util import util_json +unresolvable = list(util_json.find_json_unserializable(config_resolved)) +if unresolvable: + import warnings + warnings.warn(f'NotReproducibleWarning: Found unresolvable configuration options: {unresolvable!r}') + config_walker = ub.IndexableWalker(config_resolved) + for unresolvable_item in unresolvable: + _value = unresolvable_item['data'] + config_walker[unresolvable_item['loc']] = f'Unresolvable: {_value}' + + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + assert not unresolvable, 'should have entered dummy values for unresolvable data' + +if config['record_context']: + from geowatch.utils import process_context + proc_context = process_context.ProcessContext( + name='geowatch.tasks.fusion.predict', + type='process', + config=config_resolved, + track_emissions=config['track_emissions'], + # Extra information was adjusted in 0.15.1 to ensure more relevant + # fit params are returned here. A script + # ~/code/geowatch/geowatch/cli/experimental/fixup_predict_kwcoco_metadata.py + # exist to help update old results to use this new format. 
+ extra={ + 'fit_config': fit_config + } + ) + # assert not list(util_json.find_json_unserializable(proc_context.obj)) + info.append(proc_context.obj) + proc_context.start() + test_coco_dataset = datamodule.coco_datasets['test'] + proc_context.add_disk_info(test_coco_dataset.fpath) + +memory_monitor_timer = ub.Timer().tic() +memory_monitor_interval_seconds = 60 * 60 +with_memory_units = bool(ub.modname_to_modpath('pint')) + +#-------------------------------------------------------------------------- +#--------------------------------------------------------------------------- +torch.set_grad_enabled(False) + +EMERGENCY_INPUT_AGREEMENT_HACK = 1 and hasattr(model, 'input_norms') + +# prog.set_extra(' <will populate stats after first video>') +# pman.start() + +prog = pman.progiter(batch_iter, desc='fusion predict') +_batch_iter = iter(prog) + +if 0: + item = test_dataloader.dataset[0] + + orig_batch = next(_batch_iter) + item = orig_batch[0] + item['target'] + frame = item['frames'][0] + ub.peek(frame['modes'].values()).shape + +batch_idx = 0 + +# can ignore pman +pman.stopall() + +item = test_dataloader.dataset[0] +orig_batch = next(_batch_iter) + +# check out orig_batch + +# Iterates through every item in dataset +# We just step through loop once +batch_idx += 1 +batch_trs = [] + +# Move data onto the prediction device, grab spacetime region info +fixed_batch = [] +for item in orig_batch: + if item is None: + continue + item = item.copy() + batch_gids = [frame['gid'] for frame in item['frames']] + frame_infos = [ub.udict(f) & { + 'gid', + 'output_space_slice', + 'output_image_dsize', + 'output_weights', + 'scale_outspace_from_vid', + } for f in item['frames']] + batch_trs.append({ + 'space_slice': tuple(item['target']['space_slice']), + # 'scale': item['target']['scale'], + 'scale': item['target'].get('scale', None), + 'gids': batch_gids, + 'frame_infos': frame_infos, + 'fliprot_params': item['target'].get('fliprot_params', None) + }) + position_tensors = item.get('positional_tensors', None) + if position_tensors is not None: + for k, v in position_tensors.items(): + position_tensors[k] = v.to(device) + + filtered_frames = [] + for frame in item['frames']: + frame = frame.copy() + sensor = frame['sensor'] + if EMERGENCY_INPUT_AGREEMENT_HACK: + try: + known_sensor_modes = model.input_norms[sensor] + except KeyError: + known_sensor_modes = None + continue + filtered_modes = {} + modes = frame['modes'] + for key, mode in modes.items(): + if EMERGENCY_INPUT_AGREEMENT_HACK: + if key not in known_sensor_modes: + continue + filtered_modes[key] = mode.to(device) + frame['modes'] = filtered_modes + filtered_frames.append(frame) + item['frames'] = filtered_frames + fixed_batch.append(item) + +# fixes batch components +batch = fixed_batch + +# can view again `batch` + + +from geowatch.utils.util_netharn import _debug_inbatch_shapes +print(_debug_inbatch_shapes(batch)) + + +if memory_monitor_timer.toc() > memory_monitor_interval_seconds: + # TODO: monitor memory usage and report if it looks like we + # are about to run out of memory, and maybe do something to + # handle it. + from geowatch.utils import util_hardware + mem_info = util_hardware.get_mem_info(with_units=with_memory_units) + print(f'\n\nmem_info = {ub.urepr(mem_info, nl=1)}\n\n') + memory_monitor_timer.tic() + + +# Entire purpose of the function. Where model connects to data and runs +# neural network. 
Arbritrary code we don't control +# Need to inject something to find hidden state +# Prepare input --> model.forward_step --> prepare output +outputs = model.forward_step(batch, with_loss=False) + +# view `outputs` +# view `outputs.keys()` + +# Compatibility step +outputs = {head_key_mapping.get(k, k): v for k, v in outputs.items()} +outputs.keys() + +# Checks/hack, runs once +got_outputs = list(outputs.keys()) +prog.ensure_newline() +writable_outputs = set(got_outputs) & expected_outputs +print('got_outputs = {!r}'.format(got_outputs)) +print('writable_outputs = {!r}'.format(writable_outputs)) + +# For each item in the batch, process the results +for head_key in writable_outputs: + head_probs = outputs[head_key] + head_stitcher = stitch_managers[head_key] + chan_keep_idxs = head_stitcher.head_keep_idxs + +# hack +predicted_frame_slice = slice(None) + +num_batches = len(batch_trs) + +# ----------------------------------------------------------------------- +for bx in range(num_batches): + target: dict = batch_trs[bx] + item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] + + # check the shape, was ([2,64,64,2]) for initial run + item_head_probs.shape + + # Keep only the channels we want to write to disk + # convert to numpy + item_head_relevant_probs = [p[..., chan_keep_idxs] for p in item_head_probs] + bin_probs = [p.detach().cpu().numpy() for p in item_head_relevant_probs] + + + + # check probs + DEBUG_PRED_SPATIAL_COVERAGE=0 + + frame_infos: list[dict] = target['frame_infos'][predicted_frame_slice] + + fliprot_params: dict = target['fliprot_params'] + # Update the stitcher with this windowed prediction + for probs, frame_info in zip(bin_probs, frame_infos): + if fliprot_params is not None: + # Undo fliprot TTA + probs = data_utils.inv_fliprot(probs, **fliprot_params) + + gid = frame_info['gid'] + output_image_dsize = frame_info['output_image_dsize'] + output_space_slice = frame_info['output_space_slice'] + scale_outspace_from_vid = frame_info['scale_outspace_from_vid'] + + if DEBUG_PRED_SPATIAL_COVERAGE: + image_id_to_video_space_slices[gid].append(target['space_slice']) + image_id_to_output_space_slices[gid].append(output_space_slice) + + output_weights = frame_info.get('output_weights', None) + + # View some details + probs.shape + frame_info + + # Checks if an image is done, submits to finalize + head_stitcher.accumulate_image( + gid, output_space_slice, probs, + asset_dsize=output_image_dsize, + scale_asset_from_stitchspace=scale_outspace_from_vid, + weights=output_weights, + downweight_edges=downweight_edges, + ) diff --git a/docs/source/manual/tutorial/hidden_layer_notes.sh b/docs/source/manual/tutorial/hidden_layer_notes.sh new file mode 100644 index 000000000..0f645181c --- /dev/null +++ b/docs/source/manual/tutorial/hidden_layer_notes.sh @@ -0,0 +1,43 @@ +# Document the flow and to output hidden layers in landcover/predict + +# The code uses the model (DZYNE_LANDCOVER_MODEL) and drop6 data. +# Model is a based on pytorch for landcover classification. + +# Code requires some basic kwargs and runs the `predict` function. + +# 1. imports packages +# 2. Loads config from the `LandcoverPredictConfig` class and adds cli kwargs +# - this config class has `with_hidden` kwarg + +# 3. Unpack variables from the config, dataset and model +# - `model` is loaded by runnning + model_info = lookup_model_info(weights_filename) + ptdataset = model_info.create_dataset(input_dset) + model = model_info.load_model(weights_filename, device) + +# 4. 
Run the `CocoStitchingManager` to get the stitchers. + +# 5. For the `hidden_stitcher` option: +# - Runs the `_register_hidden_layer_hook` on with the loaded model +# - Adds `._activation_cache = {}` dict attribute to the model +# - Adds `layer_of_interest` from `model.decoder1[]` attribute +# - Clears hooks and `register_forward_hook()` using +# - `record_hidden_activation()` that detaches output and stores in +model._activation_cache['hidden'] = activation + + +# 6. Starts the main predict loop +# - Starts the `ProcessContext()` function +# - starts the `pman` progress manager +# - sets up the wrapping for the main loop `for img_info in _prog` +# - predicts current image with `_predict_single` +# - uses both stitchers + +# 7. Steps into `def _predict_single` function +# - There is a hidden_scale - need to figure what this is +# - Runs a sliding window over the image + +# 8. Extracts the hidden layer using +hidden_raw = model._activation_cache['hidden'].cpu().numpy() +# - Then hardcodes transformations of the activation layer +# - back to the image stitcher diff --git a/docs/source/manual/tutorial/module_path_test.py b/docs/source/manual/tutorial/module_path_test.py new file mode 100644 index 000000000..dc2b521e2 --- /dev/null +++ b/docs/source/manual/tutorial/module_path_test.py @@ -0,0 +1,6 @@ +# simple test to see if modeules are loaded + + + +from geowatch.tasks.fusion.predict import * +print(dir(Predictor)) diff --git a/docs/source/manual/tutorial/record_demo.sh b/docs/source/manual/tutorial/record_demo.sh new file mode 100644 index 000000000..42648c231 --- /dev/null +++ b/docs/source/manual/tutorial/record_demo.sh @@ -0,0 +1,430 @@ +__doc__=" + +Notes on exposing hidden layer descipritors in fusion/predict.py +================================================================ + +The script takes two inputs; the model and kwcoco data. Need to add a +stitching manager that gathers descritors/features from designated layer. Port +the paradigm from landcover/predict.py to fusion/predict.py. This will add +an option to output intermediate features from the model. +" + +# Overall steps for running the code: +# Setup uses the doctest info +# kwargs get overwritten by predict_kwargs +# Now the info is loaded into the namespace + +# Next is to perform the prediction _run() function +# need to assign self to predictor for the class method +# script to modify +geowatch/tasks/fusion/predict.py + +# take the hidden layer hook from the script +geowatch/tasks/landcover/predict.py + +# Start with an interactive pyton sesstion +# Run code in doctests to setup paths and config. (approx line 1017-1069) +from geowatch.tasks.fusion.predict import * # NOQA (line 1017) +# ...to... +root_dpath = ub.Path(test_dpath, 'train').ensuredir() # line 1030 + +# test files go into cache directory specified by variable +test_dpath + +# run the code for config +fit_config = kwargs = {...} + +# the model being used is fit_lightning. To generate the model +fit_lightning.main(fit_config) + +# setup the predict_kwargs that overwrite the generic kwargs +# and perform the prediction (approx lines 1060-1069) + + +# Now we have the required info in the namespace. 
+ +# Now move to the main script to run the predictor (line 1370) +# need to assign self to predictor for the class method + + +# History of commands in interactive python session + +# From doctest, import modules and set up some config args +from geowatch.tasks.fusion.predict import * # NOQA +import os +from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings +disable_lightning_hardware_warnings() +args = None +cmdline = False +devices = None +test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir() + +test_dpath # view path + +results_path = (test_dpath / 'predict').ensuredir() +results_path.delete() +results_path.ensuredir() + +# generate kwcoco toydataset +import kwcoco +train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral') +test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral') +20/4: + >>> root_dpath = ub.Path(test_dpath, 'train').ensuredir() + >>> fit_config = kwargs = { + ... 'subcommand': 'fit', + ... 'fit.data.train_dataset': train_dset.fpath, + ... 'fit.data.time_steps': 2, + ... 'fit.data.time_span': "2m", + ... 'fit.data.chip_dims': 64, + ... 'fit.data.time_sampling': 'hardish3', + ... 'fit.data.num_workers': 0, + ... #'package_fpath': package_fpath, + ... 'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer', + ... 'fit.model.init_args.global_change_weight': 1.0, + ... 'fit.model.init_args.global_class_weight': 1.0, + ... 'fit.model.init_args.global_saliency_weight': 1.0, + ... 'fit.optimizer.class_path': 'torch.optim.SGD', + ... 'fit.optimizer.init_args.lr': 1e-5, + ... 'fit.trainer.max_steps': 10, + ... 'fit.trainer.accelerator': 'cpu', + ... 'fit.trainer.devices': 1, + ... 'fit.trainer.max_epochs': 3, + ... 'fit.trainer.log_every_n_steps': 1, + ... 'fit.trainer.default_root_dir': os.fspath(root_dpath), + ... } + >>> from geowatch.tasks.fusion import fit_lightning + >>> package_fpath = root_dpath / 'final_package.pt' +20/5: >>> fit_lightning.main(fit_config) +20/6: + >>> # Unfortunately, its not as easy to get the package path of + >>> # this call.. 
+ >>> assert ub.Path(package_fpath).exists() + >>> # Predict via that model + >>> predict_kwargs = kwargs = { + >>> 'package_fpath': package_fpath, + >>> 'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json', + >>> 'test_dataset': test_dset.fpath, + >>> 'datamodule': 'KWCocoVideoDataModule', + >>> 'batch_size': 1, + >>> 'num_workers': 0, + >>> 'devices': devices, + >>> 'draw_batches': 1, + >>> } +20/7: >>> result_dataset = predict(**kwargs) +20/8: + import rich + from rich.markup import escape + config = PredictConfig.cli(cmdline=cmdline, data=kwargs, strict=True) + rich.print('config = {}'.format(escape(ub.urepr(config, nl=2)))) +20/9: predictor = Predictor(config) +20/10: + predictor._load_model() + predictor._load_dataset() +20/11: self = predictor +20/12: + datamodule = self.datamodule + model = self.model + config = self.config + + test_coco_dataset = datamodule.coco_datasets['test'] + + # test_torch_dataset = datamodule.torch_datasets['test'] + # T, H, W = test_torch_dataset.window_dims + + # Create the results dataset as a copy of the test CocoDataset + print('Populate result dataset') + result_dataset: kwcoco.CocoDataset = test_coco_dataset.copy() + + # Remove all annotations in the results copy + if config['clear_annots']: + result_dataset.clear_annotations() + + # Change all paths to be absolute paths + result_dataset.reroot(absolute=True) + if not config['pred_dataset']: + raise ValueError( + f'Must specify path to the output (predicted) kwcoco file. ' + f'Got {config["pred_dataset"]=}') + result_dataset.fpath = str(ub.Path(config['pred_dataset']).expand()) + + from geowatch.utils.lightning_ext import util_device + print('devices = {!r}'.format(config['devices'])) + devices = util_device.coerce_devices(config['devices']) + print('devices = {!r}'.format(devices)) + if len(devices) > 1: + raise NotImplementedError('TODO: handle multiple devices') + device = devices[0] + + fit_config = self.fit_config +20/13: + import rich + + print('Predict on device = {!r}'.format(device)) + downweight_edges = config.downweight_edges + + UNPACKAGE_METHOD_HACK = 0 + if UNPACKAGE_METHOD_HACK: + # unpackage model hack + from geowatch.tasks.fusion import methods + unpackged_method = methods.MultimodalTransformer(**model.hparams) + unpackged_method.load_state_dict(model.state_dict()) + model = unpackged_method + + model = model.to(device) + global_head_weights = getattr(model, 'global_head_weights', {}) + if config['with_change'] == 'auto': + config['with_change'] = getattr(model, 'global_change_weight', 1.0) or global_head_weights.get('change', 1) + if config['with_class'] == 'auto': + config['with_class'] = getattr(model, 'global_class_weight', 1.0) or global_head_weights.get('class', 1) + if config['with_saliency'] == 'auto': + config['with_saliency'] = getattr(model, 'global_saliency_weight', 0.0) or global_head_weights.get('saliency', 1) +20/14: + test_dataloader = datamodule.test_dataloader() + batch_iter = iter(test_dataloader) + + from kwutil import util_progress + pman = util_progress.ProgressManager(backend='rich') + + # prog = ub.ProgIter(batch_iter, desc='fusion predict', verbose=1, freq=1) + + # Make threads after starting background proces. 
+ if config.write_workers == 'datamodule': + config.write_workers = datamodule.num_workers + writer_queue = util_parallel.BlockingJobQueue( + mode='thread', + # mode='serial', + max_workers=config.write_workers + ) +20/15: + result_fpath = ub.Path(result_dataset.fpath) + result_fpath.parent.ensuredir() + print('result_fpath = {!r}'.format(result_fpath)) + + stitch_managers = build_stitching_managers( + config, model, result_dataset, + writer_queue=writer_queue + ) +20/16: stitch_managers +20/17: # error with indentation, move to 20/18 +20/18: + expected_outputs = set(stitch_managers.keys()) + got_outputs = None + writable_outputs = None + + print('Expected outputs: ' + str(expected_outputs)) + + head_key_mapping = { + 'saliency_probs': 'saliency', + 'class_probs': 'class', + 'change_probs': 'change', + } +20/19: # error, need to import _jsonify. Jump to 20/21 + +20/21: from geowatch.tasks.fusion.predict import _jsonify +20/22: + info = result_dataset.dataset.get('info', []) + + pred_dpath = ub.Path(result_dataset.fpath).parent + rich.print(f'Pred Dpath: [link={pred_dpath}]{pred_dpath}[/link]') + + DRAW_BATCHES = config.draw_batches + if DRAW_BATCHES: + viz_batch_dpath = (pred_dpath / '_viz_pred_batches').ensuredir() + + config_resolved = _jsonify(config.asdict()) + fit_config = _jsonify(fit_config) + + from kwcoco.util import util_json + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + if unresolvable: + import warnings + warnings.warn(f'NotReproducibleWarning: Found unresolvable configuration options: {unresolvable!r}') + config_walker = ub.IndexableWalker(config_resolved) + for unresolvable_item in unresolvable: + _value = unresolvable_item['data'] + config_walker[unresolvable_item['loc']] = f'Unresolvable: {_value}' + + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + assert not unresolvable, 'should have entered dummy values for unresolvable data' +20/23: + if config['record_context']: + from geowatch.utils import process_context + proc_context = process_context.ProcessContext( + name='geowatch.tasks.fusion.predict', + type='process', + config=config_resolved, + track_emissions=config['track_emissions'], + # Extra information was adjusted in 0.15.1 to ensure more relevant + # fit params are returned here. A script + # ~/code/geowatch/geowatch/cli/experimental/fixup_predict_kwcoco_metadata.py + # exist to help update old results to use this new format. 
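            # ProcessContext is the provenance record for this run: it
            # captures the resolved prediction config, the fit-time
            # hyperparameters passed through `extra` below, and optional
            # emissions tracking; disk info for the test dataset is added to
            # it right after it starts.  Its `obj` dict is appended to the
            # kwcoco `info` section so the prediction file stays traceable to
            # the model and data that produced it.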
+ extra={ + 'fit_config': fit_config + } + ) + # assert not list(util_json.find_json_unserializable(proc_context.obj)) + info.append(proc_context.obj) + proc_context.start() + test_coco_dataset = datamodule.coco_datasets['test'] + proc_context.add_disk_info(test_coco_dataset.fpath) + + memory_monitor_timer = ub.Timer().tic() + memory_monitor_interval_seconds = 60 * 60 + with_memory_units = bool(ub.modname_to_modpath('pint')) +20/24: torch.set_grad_enabled(False) +20/25: + EMERGENCY_INPUT_AGREEMENT_HACK = 1 and hasattr(model, 'input_norms') + + # prog.set_extra(' <will populate stats after first video>') + # pman.start() + + prog = pman.progiter(batch_iter, desc='fusion predict') + _batch_iter = iter(prog) + if 0: + item = test_dataloader.dataset[0] + + orig_batch = next(_batch_iter) + item = orig_batch[0] + item['target'] + frame = item['frames'][0] + ub.peek(frame['modes'].values()).shape + + batch_idx = 0 +20/26: pman.stopall() +20/27: item = test_dataloader.dataset[0] +20/28: orig_batch = next(_batch_iter) +20/29: orig_batch +20/30: + batch_idx += 1 + batch_trs = [] + # Move data onto the prediction device, grab spacetime region info + fixed_batch = [] + for item in orig_batch: + if item is None: + continue + item = item.copy() + batch_gids = [frame['gid'] for frame in item['frames']] + frame_infos = [ub.udict(f) & { + 'gid', + 'output_space_slice', + 'output_image_dsize', + 'output_weights', + 'scale_outspace_from_vid', + } for f in item['frames']] + batch_trs.append({ + 'space_slice': tuple(item['target']['space_slice']), + # 'scale': item['target']['scale'], + 'scale': item['target'].get('scale', None), + 'gids': batch_gids, + 'frame_infos': frame_infos, + 'fliprot_params': item['target'].get('fliprot_params', None) + }) + position_tensors = item.get('positional_tensors', None) + if position_tensors is not None: + for k, v in position_tensors.items(): + position_tensors[k] = v.to(device) + + filtered_frames = [] + for frame in item['frames']: + frame = frame.copy() + sensor = frame['sensor'] + if EMERGENCY_INPUT_AGREEMENT_HACK: + try: + known_sensor_modes = model.input_norms[sensor] + except KeyError: + known_sensor_modes = None + continue + filtered_modes = {} + modes = frame['modes'] + for key, mode in modes.items(): + if EMERGENCY_INPUT_AGREEMENT_HACK: + if key not in known_sensor_modes: + continue + filtered_modes[key] = mode.to(device) + frame['modes'] = filtered_modes + filtered_frames.append(frame) + item['frames'] = filtered_frames + fixed_batch.append(item) +20/31: batch = fixed_batch +20/32: batch +20/33: + from geowatch.utils.util_netharn import _debug_inbatch_shapes + print(_debug_inbatch_shapes(batch)) +20/34: + if memory_monitor_timer.toc() > memory_monitor_interval_seconds: + # TODO: monitor memory usage and report if it looks like we + # are about to run out of memory, and maybe do something to + # handle it. 
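    # This check fires at most once per `memory_monitor_interval_seconds`
    # (an hour here): it queries host memory through
    # util_hardware.get_mem_info, prints the snapshot, and re-arms the timer
    # with `.tic()`.  `with_units` is only True when the optional `pint`
    # package is importable, which lets the report use human-readable units.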
+ from geowatch.utils import util_hardware + mem_info = util_hardware.get_mem_info(with_units=with_memory_units) + print(f'\n\nmem_info = {ub.urepr(mem_info, nl=1)}\n\n') + memory_monitor_timer.tic() +20/35: mem_info +20/36: outputs = model.forward_step(batch, with_loss=False) +20/37: outputs +20/38: # error +20/39: outputs.keys() +20/40: outputs = {head_key_mapping.get(k, k): v for k, v in outputs.items()} +20/41: outputs.keys() +20/42: + got_outputs = list(outputs.keys()) + prog.ensure_newline() + writable_outputs = set(got_outputs) & expected_outputs + print('got_outputs = {!r}'.format(got_outputs)) + print('writable_outputs = {!r}'.format(writable_outputs)) +20/43: + for head_key in writable_outputs: + head_probs = outputs[head_key] +20/44: head_key +20/45: head_probs +20/46-48: # errors +20/49: + head_stitcher = stitch_managers[head_key] + chan_keep_idxs = head_stitcher.head_keep_idxs +20/50: chan_keep_idxs +20/51: predicted_frame_slice = slice(None) +20/52: + num_batches = len(batch_trs) + + for bx in range(num_batches): + target: dict = batch_trs[bx] + item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] +20/53: item_head_probs +20/54: item_head_probs.shape +20/55: + item_head_relevant_probs = [p[..., chan_keep_idxs] for p in item_head_probs] + bin_probs = [p.detach().cpu().numpy() for p in item_head_relevant_probs] +20/56: # error +20/57: DEBUG_PRED_SPATIAL_COVERAGE=0 +20/58: + frame_infos: list[dict] = target['frame_infos'][predicted_frame_slice] + + fliprot_params: dict = target['fliprot_params'] + # Update the stitcher with this windowed prediction + for probs, frame_info in zip(bin_probs, frame_infos): + if fliprot_params is not None: + # Undo fliprot TTA + probs = data_utils.inv_fliprot(probs, **fliprot_params) + + gid = frame_info['gid'] + output_image_dsize = frame_info['output_image_dsize'] + output_space_slice = frame_info['output_space_slice'] + scale_outspace_from_vid = frame_info['scale_outspace_from_vid'] + + if DEBUG_PRED_SPATIAL_COVERAGE: + image_id_to_video_space_slices[gid].append(target['space_slice']) + image_id_to_output_space_slices[gid].append(output_space_slice) + + output_weights = frame_info.get('output_weights', None) +20/59: probs +20/60: probs.shape +20/61: frame_info +20/62: + head_stitcher.accumulate_image( + gid, output_space_slice, probs, + asset_dsize=output_image_dsize, + scale_asset_from_stitchspace=scale_outspace_from_vid, + weights=output_weights, + downweight_edges=downweight_edges, + ) diff --git a/docs/source/manual/tutorial/tutorial9_hidden_layer.py b/docs/source/manual/tutorial/tutorial9_hidden_layer.py index 96b8b90e2..8d57c8292 100644 --- a/docs/source/manual/tutorial/tutorial9_hidden_layer.py +++ b/docs/source/manual/tutorial/tutorial9_hidden_layer.py @@ -1,12 +1,14 @@ -# Tutorial to extract hidden layer features from a pre-trained model +# Tutorial code to extract hidden layer features from a pre-trained model +# Need to then pass it to the stitching manager import sys import os import ubelt as ub - -# import module +# -------------------------------------------------------------- +# import module and details from the doctest example from geowatch.tasks.fusion.predict import * # NOQA +from geowatch.tasks.fusion.predict import _register_hidden_layer_hook print(Predictor) from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings disable_lightning_hardware_warnings() @@ -52,11 +54,12 @@ fit_config = kwargs = { from geowatch.tasks.fusion import fit_lightning package_fpath = root_dpath / 
'final_package.pt' -# sets up the model +# sets up the training model and trains it fit_lightning.main(fit_config) # Unfortunately, its not as easy to get the package path of # this call.. +# Looks like it setups up kwargs for the prediction model assert ub.Path(package_fpath).exists() # Predict via that model predict_kwargs = kwargs = { @@ -69,9 +72,12 @@ predict_kwargs = kwargs = { 'devices': devices, 'draw_batches': 1, } - -# Does this run the prediction model? -result_dataset = predict(**kwargs) +#--------------------------------- +# Performs a prediction on the dataset +# this runs parts of the code shown below. Script runs the various portions +#------------------------------------- +## result_dataset = predict(**kwargs) +#--------------------------------------- #----------------------------------------------------------------- # Load config, model and dataset into predictor model @@ -174,8 +180,8 @@ print('result_fpath = {!r}'.format(result_fpath)) # Need a change here to include hidden layer # 226-247 pass info # Stitching manager needs bands and short code -# chan_code = is important -# create a stitching manager and add to the dicstionary +# chan_code = is important, provides the path to the features/activations +# create a stitching manager and add to the dictionary # Model knows what classes it wants to predict # creates a 4-channel raster, then takes "saliency" # create another stitching manager to gather descriptor features @@ -366,6 +372,23 @@ if memory_monitor_timer.toc() > memory_monitor_interval_seconds: # Prepare input --> model.forward_step --> prepare output outputs = model.forward_step(batch, with_loss=False) + +# Hidden_layers have been extracted at this point +print(model._activation_cache.keys()) +print(model._activation_cache['hidden']) +print(model._activation_cache['hidden'].shape) + +#----this needs to get added to the source code + +hidden_stitcher = CocoStitchingManager( + result_dataset + 'hidden_layers', + chan_code='hidden_layers', + stiching_space='image', + writer_queue=writer_queue, + assets_dname=config.assets_dname, +) + # view `outputs` # view `outputs.keys()` @@ -376,7 +399,8 @@ outputs.keys() # Checks/hack, runs once got_outputs = list(outputs.keys()) prog.ensure_newline() -writable_outputs = set(got_outputs) & expected_outputs +# This code eliminates the "hidden_layers" key +writable_outputs = expected_outputs & set(got_outputs) print('got_outputs = {!r}'.format(got_outputs)) print('writable_outputs = {!r}'.format(writable_outputs)) @@ -392,6 +416,8 @@ predicted_frame_slice = slice(None) num_batches = len(batch_trs) # ----------------------------------------------------------------------- +# This code looks like it does some cleanup, filtering and then +# writes the results. 
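# Note: the loop below looks up stitch_managers['hidden_layers'], which will
# raise a KeyError unless a stitcher was registered under that name (the
# construction near the forward_step call above is commented out).  A minimal
# sketch of such a registration, mirroring that commented-out block with the
# missing comma between the first two arguments added; the exact
# CocoStitchingManager signature (including the `stiching_space` spelling) is
# taken from that block and should be checked against predict.py.
if 'hidden_layers' not in stitch_managers:
    stitch_managers['hidden_layers'] = CocoStitchingManager(
        result_dataset,
        'hidden_layers',
        chan_code='hidden_layers',
        stiching_space='image',
        writer_queue=writer_queue,
        assets_dname=config.assets_dname,
    )
# For the stock predict loop to pick this head up automatically, the
# registration would have to happen before `expected_outputs` and
# `writable_outputs` are computed, since that intersection is what drops
# the 'hidden_layers' key.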
for bx in range(num_batches): target: dict = batch_trs[bx] item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] @@ -433,8 +459,9 @@ for bx in range(num_batches): print(probs.shape) print(frame_info) print(f"model._activation['hidden'] {model._activation_cache['hidden']}") + print(f"\n hidden layer shape {model._activation_cache['hidden'].shape}") - + head_stitcher = stitch_managers["hidden_layers"] # Checks if an image is done, submits to finalize head_stitcher.accumulate_image( diff --git a/geowatch/tasks/fusion/predict.py b/geowatch/tasks/fusion/predict.py index 6d0f545b1..06aeb1172 100644 --- a/geowatch/tasks/fusion/predict.py +++ b/geowatch/tasks/fusion/predict.py @@ -229,14 +229,16 @@ def _register_hidden_layer_hook(model): def record_hidden_activation(layer, input, output): activation = output.detach() model._activation_cache['hidden'] = activation - print(f"Hidden layer activation shape {activation.shape}") + print(f"Hidden Layer Extracted! Shape is {activation.shape}") # Desired layer is the 5th layer of the 2nd encoder layer + # See `/docs/source/manual/tutorial/fusion_model_layer_info.sh` + # for an example structure of the model layer_of_interest = list(model.encoder.children())[1][6] layer_of_interest._forward_hooks.clear() layer_of_interest.register_forward_hook(record_hidden_activation) -# ---------------------------------------------------------- +# ---------------------------------------------------------- # --------------Add hidden layer hook to model---------------- -- GitLab From 7e469cba48922ce2a986e576cf08fcd4448d00a4 Mon Sep 17 00:00:00 2001 From: joncrall <jon.crall@kitware.com> Date: Thu, 20 Jun 2024 13:11:24 -0400 Subject: [PATCH 3/3] Hack outputs to contain hidden features --- .../manual/tutorial/tutorial9_hidden_layer.py | 20 ++++++------- geowatch/tasks/fusion/predict.py | 29 +++++++++---------- geowatch_tpl/submodules/loss-of-plasticity | 2 +- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/docs/source/manual/tutorial/tutorial9_hidden_layer.py b/docs/source/manual/tutorial/tutorial9_hidden_layer.py index 8d57c8292..e1db42bf2 100644 --- a/docs/source/manual/tutorial/tutorial9_hidden_layer.py +++ b/docs/source/manual/tutorial/tutorial9_hidden_layer.py @@ -378,16 +378,16 @@ print(model._activation_cache.keys()) print(model._activation_cache['hidden']) print(model._activation_cache['hidden'].shape) -#----this needs to get added to the source code - -hidden_stitcher = CocoStitchingManager( - result_dataset - 'hidden_layers', - chan_code='hidden_layers', - stiching_space='image', - writer_queue=writer_queue, - assets_dname=config.assets_dname, -) +# #----this needs to get added to the source code + +# hidden_stitcher = CocoStitchingManager( +# result_dataset +# 'hidden_layers', +# chan_code='hidden_layers', +# stiching_space='image', +# writer_queue=writer_queue, +# assets_dname=config.assets_dname, +# ) # view `outputs` # view `outputs.keys()` diff --git a/geowatch/tasks/fusion/predict.py b/geowatch/tasks/fusion/predict.py index 06aeb1172..87d3db976 100644 --- a/geowatch/tasks/fusion/predict.py +++ b/geowatch/tasks/fusion/predict.py @@ -212,29 +212,28 @@ def _register_hidden_layer_hook(model): # print("info on model", dir(model)) - print("Enumerate over model.children()\n") - for i, layer in enumerate(model.children()): - print(f"Layer {i}: {layer}") + # print("Enumerate over model.children()\n") + # for i, layer in enumerate(model.children()): + # print(f"Layer {i}: {layer}") # Not sure this is the correct code - encoder_layers = 
list(model.encoder.children())
-
-    print(f"\nNumber of encoder layers {len(encoder_layers)}\n")
+    # Hack to grab the inputs to one of the heads
+    # This will let us grab pre-formatted spacetime features
+    # out of the multimodal model.
+    available_decoders = (ub.oset(['saliency', 'class', 'change']) & model.heads.keys())
+    chosen_head_key = available_decoders[0]
+    layer_of_interest = model.heads[chosen_head_key].hidden.hidden0.conv
 
-    print("Enumerate over model.encoder.children()\n")
-    for i, layer in enumerate(encoder_layers):
-        print(f"Encoder Layer {i}: {layer}")
-
-    def record_hidden_activation(layer, input, output):
-        activation = output.detach()
+    def record_hidden_activation(layer, inputs, output):
+        assert len(inputs) == 1
+        input_features = inputs[0]
+        activation = input_features.detach()
         model._activation_cache['hidden'] = activation
         print(f"Hidden Layer Extracted! Shape is {activation.shape}")
 
-    # Desired layer is the 5th layer of the 2nd encoder layer
     # See `/docs/source/manual/tutorial/fusion_model_layer_info.sh`
     # for an example structure of the model
-    layer_of_interest = list(model.encoder.children())[1][6]
     layer_of_interest._forward_hooks.clear()
     layer_of_interest.register_forward_hook(record_hidden_activation)
 
@@ -549,7 +548,7 @@ def resolve_datamodule(config, model, datamodule_defaults, fit_config):
     DZYNE_MODEL_HACK = 1
     if DZYNE_MODEL_HACK and isinstance(config['package_fpath'], str):
         package_fpath = ub.Path(config['package_fpath'])
         if package_fpath.stem == 'lc_rgb_fusion_model_package':
             # This model has an issue with the L8 features it was trained on
             datamodule_vars['exclude_sensors'] = ['L8']
 
diff --git a/geowatch_tpl/submodules/loss-of-plasticity b/geowatch_tpl/submodules/loss-of-plasticity
index eb0868220..50f7a5614 160000
--- a/geowatch_tpl/submodules/loss-of-plasticity
+++ b/geowatch_tpl/submodules/loss-of-plasticity
@@ -1 +1 @@
-Subproject commit eb0868220cebb4930476ff51110cffddfb983b0a
+Subproject commit 50f7a561445e734c7444dc4f354d9b22b2644cb1
-- 
GitLab