From 436e12c1dea1cfb59d1ffc73cf1abc56e5cc81ae Mon Sep 17 00:00:00 2001 From: Paul Beasly <paul.beasly@kitware.com> Date: Wed, 12 Jun 2024 17:36:23 -0400 Subject: [PATCH 1/3] Update fusion/predict.py with hidden_layers --- docs/source/manual/tutorial/__init__.py | 0 .../tutorial/encoder_layers_fusion_predict.sh | 133 ++++++ .../tutorial/fusion_model_layer_info.sh | 266 +++++++++++ .../manual/tutorial/tutorial9_hidden_layer.py | 446 ++++++++++++++++++ geowatch/tasks/fusion/predict.py | 37 ++ 5 files changed, 882 insertions(+) create mode 100644 docs/source/manual/tutorial/__init__.py create mode 100644 docs/source/manual/tutorial/encoder_layers_fusion_predict.sh create mode 100644 docs/source/manual/tutorial/fusion_model_layer_info.sh create mode 100644 docs/source/manual/tutorial/tutorial9_hidden_layer.py diff --git a/docs/source/manual/tutorial/__init__.py b/docs/source/manual/tutorial/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/source/manual/tutorial/encoder_layers_fusion_predict.sh b/docs/source/manual/tutorial/encoder_layers_fusion_predict.sh new file mode 100644 index 000000000..c302bcb91 --- /dev/null +++ b/docs/source/manual/tutorial/encoder_layers_fusion_predict.sh @@ -0,0 +1,133 @@ +# output of encoder layers + +Encoder Layer 0: Linear(in_features=560, out_features=128, bias=True) +Encoder Layer 1: Sequential( + (0): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (1): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (2): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (3): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): 
Linear(in_features=128, out_features=128, bias=True) + ) + ) + (4): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (5): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (6): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (7): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) +) diff --git a/docs/source/manual/tutorial/fusion_model_layer_info.sh b/docs/source/manual/tutorial/fusion_model_layer_info.sh new file mode 100644 index 000000000..a4d96b5da --- /dev/null +++ b/docs/source/manual/tutorial/fusion_model_layer_info.sh @@ -0,0 +1,266 @@ +# output of layers for fusion predict.py mdoel + +Layer 0: RobustModuleDict( + (*): RobustModuleDict( + (B1|B10|B11|B8|B8a): InputNorm() + ) +) +Layer 1: RobustModuleDict( + (*): RobustModuleDict( + (B1|B10|B11|B8|B8a): RearrangeTokenizer( + (foot): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv2d(5, 6, kernel_size=(1, 1), stride=(1, 1)) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv2d(6, 6, kernel_size=(1, 1), stride=(1, 1)) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv2d(6, 7, kernel_size=(1, 1), stride=(1, 1)) + (noli): ReLU(inplace=True) + ) + (output): Conv2d(7, 8, kernel_size=(1, 1), stride=(1, 1)) + ) + (skip): Conv2d(5, 8, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + ) +) +Layer 2: MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=1, out_features=3, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): 
Conv0d(in_features=3, out_features=4, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv0d(in_features=4, out_features=6, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=6, out_features=8, bias=True) + ) + (skip): Conv0d(in_features=1, out_features=8, bias=True) +) +Layer 3: MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=16, out_features=14, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=14, out_features=12, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv0d(in_features=12, out_features=10, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=10, out_features=8, bias=True) + ) + (skip): Conv0d(in_features=16, out_features=8, bias=True) +) +Layer 4: MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=16, out_features=14, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=14, out_features=12, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden2): ConvNormNd( + (conv): Conv0d(in_features=12, out_features=10, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=10, out_features=8, bias=True) + ) + (skip): Conv0d(in_features=16, out_features=8, bias=True) +) +Layer 5: FusionEncoder( + (first): Linear(in_features=560, out_features=128, bias=True) + (layers): Sequential( + (0): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (1): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (2): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (3): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): 
ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (4): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (5): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (6): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + (7): ChannelwiseTransformerEncoderLayer( + (attention_modules): ModuleDict( + (time mode height width): ResidualAttentionSequential( + (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) + (1): MultiheadSelfAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) + ) + ) + ) + (mlp): ResidualSequential( + (0): Linear(in_features=128, out_features=128, bias=True) + (1): Dropout(p=0.1, inplace=False) + (2): GELU(approximate='none') + (3): Linear(in_features=128, out_features=128, bias=True) + ) + ) + ) +) +Layer 6: ModuleDict( + (change): CrossEntropyLoss() + (saliency): FocalLoss() + (class): FocalLoss() +) +Layer 7: ModuleDict( + (change): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=128, out_features=86, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=86, out_features=44, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=44, out_features=2, bias=True) + ) + ) + (saliency): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=128, out_features=86, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): Conv0d(in_features=86, out_features=44, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=44, out_features=2, bias=True) + ) + ) + (class): MultiLayerPerceptronNd( + (hidden): Sequential( + (hidden0): ConvNormNd( + (conv): Conv0d(in_features=128, out_features=87, bias=True) + (noli): ReLU(inplace=True) + ) + (hidden1): ConvNormNd( + (conv): 
Conv0d(in_features=87, out_features=45, bias=True) + (noli): ReLU(inplace=True) + ) + (output): Conv0d(in_features=45, out_features=4, bias=True) + ) + ) +) +Layer 8: SinePositionalEncoding() +Layer 9: SinePositionalEncoding() diff --git a/docs/source/manual/tutorial/tutorial9_hidden_layer.py b/docs/source/manual/tutorial/tutorial9_hidden_layer.py new file mode 100644 index 000000000..96b8b90e2 --- /dev/null +++ b/docs/source/manual/tutorial/tutorial9_hidden_layer.py @@ -0,0 +1,446 @@ +# Tutorial to extract hidden layer features from a pre-trained model + +import sys +import os +import ubelt as ub + + +# import module +from geowatch.tasks.fusion.predict import * # NOQA +print(Predictor) +from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings +disable_lightning_hardware_warnings() +args = None +cmdline = False +devices = None +test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir() +results_path = (test_dpath / 'predict').ensuredir() +results_path.delete() +results_path.ensuredir() + +# generate kwcoco toydataset +import kwcoco +train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral') +test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral') + +root_dpath = ub.Path(test_dpath, 'train').ensuredir() + +# create config +fit_config = kwargs = { + 'subcommand': 'fit', + 'fit.data.train_dataset': train_dset.fpath, + 'fit.data.time_steps': 2, + 'fit.data.time_span': "2m", + 'fit.data.chip_dims': 64, + 'fit.data.time_sampling': 'hardish3', + 'fit.data.num_workers': 0, + # 'package_fpath': package_fpath, + 'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer', + 'fit.model.init_args.global_change_weight': 1.0, + 'fit.model.init_args.global_class_weight': 1.0, + 'fit.model.init_args.global_saliency_weight': 1.0, + 'fit.optimizer.class_path': 'torch.optim.SGD', + 'fit.optimizer.init_args.lr': 1e-5, + 'fit.trainer.max_steps': 10, + 'fit.trainer.accelerator': 'cpu', + 'fit.trainer.devices': 1, + 'fit.trainer.max_epochs': 3, + 'fit.trainer.log_every_n_steps': 1, + 'fit.trainer.default_root_dir': os.fspath(root_dpath), +} + +from geowatch.tasks.fusion import fit_lightning +package_fpath = root_dpath / 'final_package.pt' + +# sets up the model +fit_lightning.main(fit_config) + +# Unfortunately, its not as easy to get the package path of +# this call.. +assert ub.Path(package_fpath).exists() +# Predict via that model +predict_kwargs = kwargs = { + 'package_fpath': package_fpath, + 'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json', + 'test_dataset': test_dset.fpath, + 'datamodule': 'KWCocoVideoDataModule', + 'batch_size': 1, + 'num_workers': 0, + 'devices': devices, + 'draw_batches': 1, +} + +# Does this run the prediction model? 
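+# (Yes - predict() runs the full prediction pipeline end-to-end and writes the
+# pred.kwcoco.json requested above. The remainder of this tutorial re-traces
+# the same steps interactively so the hidden-layer activations can be
+# inspected along the way.)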
+result_dataset = predict(**kwargs) + +#----------------------------------------------------------------- +# Load config, model and dataset into predictor model +import rich +from rich.markup import escape +config = PredictConfig.cli(cmdline=cmdline, data=kwargs, strict=True) +rich.print('config = {}'.format(escape(ub.urepr(config, nl=2)))) +predictor = Predictor(config) +predictor._load_model() +predictor._load_dataset() + +#--------------------------------------------------------------- +# Move to the predictor _run() function (line 1370) +# Need to assign 'self' due to class method +self = predictor +datamodule = self.datamodule +model = self.model +config = self.config +test_coco_dataset = datamodule.coco_datasets['test'] + +# test_torch_dataset = datamodule.torch_datasets['test'] +# T, H, W = test_torch_dataset.window_dims + +# Create the results dataset as a copy of the test CocoDataset +print('Populate result dataset') +result_dataset: kwcoco.CocoDataset = test_coco_dataset.copy() + +# Remove all annotations in the results copy +if config['clear_annots']: + result_dataset.clear_annotations() + +# Change all paths to be absolute paths +result_dataset.reroot(absolute=True) +if not config['pred_dataset']: + raise ValueError( + f'Must specify path to the output (predicted) kwcoco file. ' + f'Got {config["pred_dataset"]=}') +result_dataset.fpath = str(ub.Path(config['pred_dataset']).expand()) + +from geowatch.utils.lightning_ext import util_device +print('devices = {!r}'.format(config['devices'])) +devices = util_device.coerce_devices(config['devices']) +print('devices = {!r}'.format(devices)) +if len(devices) > 1: + raise NotImplementedError('TODO: handle multiple devices') +device = devices[0] + +fit_config = self.fit_config + +#--------------------------------------------------------------- +# The critical work happens here at _predict_critical_loop(), line 1407 +# The data module produces batches +# We move into the function to perform the operations (line 557) + +print('Predict on device = {!r}'.format(device)) +downweight_edges = config.downweight_edges + +UNPACKAGE_METHOD_HACK = 0 +if UNPACKAGE_METHOD_HACK: + # unpackage model hack + from geowatch.tasks.fusion import methods + unpackged_method = methods.MultimodalTransformer(**model.hparams) + unpackged_method.load_state_dict(model.state_dict()) + model = unpackged_method + +model = model.to(device) + +# Introspection of config +# Resolve what tasks are requested by looking at what heads are available. +global_head_weights = getattr(model, 'global_head_weights', {}) +if config['with_change'] == 'auto': + config['with_change'] = getattr(model, 'global_change_weight', 1.0) or global_head_weights.get('change', 1) +if config['with_class'] == 'auto': + config['with_class'] = getattr(model, 'global_class_weight', 1.0) or global_head_weights.get('class', 1) +if config['with_saliency'] == 'auto': + config['with_saliency'] = getattr(model, 'global_saliency_weight', 0.0) or global_head_weights.get('saliency', 1) + + +test_dataloader = datamodule.test_dataloader() +batch_iter = iter(test_dataloader) + +from kwutil import util_progress +pman = util_progress.ProgressManager(backend='rich') + +# prog = ub.ProgIter(batch_iter, desc='fusion predict', verbose=1, freq=1) + +# Make threads after starting background proces. 
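+# (The BlockingJobQueue below is handed to the stitching managers; finished
+# images are submitted to it so stitched rasters are written to disk in worker
+# threads instead of blocking the predict loop.)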
+if config.write_workers == 'datamodule': + config.write_workers = datamodule.num_workers +writer_queue = util_parallel.BlockingJobQueue( + mode='thread', + # mode='serial', + max_workers=config.write_workers +) +result_fpath = ub.Path(result_dataset.fpath) +result_fpath.parent.ensuredir() +print('result_fpath = {!r}'.format(result_fpath)) + +# Definition of the building stitching managers starts line 198 +# Need a change here to include hidden layer +# 226-247 pass info +# Stitching manager needs bands and short code +# chan_code = is important +# create a stitching manager and add to the dicstionary +# Model knows what classes it wants to predict +# creates a 4-channel raster, then takes "saliency" +# create another stitching manager to gather descriptor features +# "coco-stitcher" + +#-------------------------------------------------------------- +stitch_managers = build_stitching_managers( + config, model, result_dataset, + writer_queue=writer_queue +) +stitch_managers + +#-------------------------------------------------------------- + +expected_outputs = set(stitch_managers.keys()) +got_outputs = None +writable_outputs = None + +print('Expected outputs: ' + str(expected_outputs)) + +head_key_mapping = { + 'saliency_probs': 'saliency', + 'class_probs': 'class', + 'change_probs': 'change', + 'hidden_layers': 'hidden_layers', +} + +from geowatch.tasks.fusion.predict import _jsonify + +info = result_dataset.dataset.get('info', []) + +pred_dpath = ub.Path(result_dataset.fpath).parent +rich.print(f'Pred Dpath: [link={pred_dpath}]{pred_dpath}[/link]') + +DRAW_BATCHES = config.draw_batches +if DRAW_BATCHES: + viz_batch_dpath = (pred_dpath / '_viz_pred_batches').ensuredir() + +config_resolved = _jsonify(config.asdict()) +fit_config = _jsonify(fit_config) + +from kwcoco.util import util_json +unresolvable = list(util_json.find_json_unserializable(config_resolved)) +if unresolvable: + import warnings + warnings.warn(f'NotReproducibleWarning: Found unresolvable configuration options: {unresolvable!r}') + config_walker = ub.IndexableWalker(config_resolved) + for unresolvable_item in unresolvable: + _value = unresolvable_item['data'] + config_walker[unresolvable_item['loc']] = f'Unresolvable: {_value}' + + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + assert not unresolvable, 'should have entered dummy values for unresolvable data' + +if config['record_context']: + from geowatch.utils import process_context + proc_context = process_context.ProcessContext( + name='geowatch.tasks.fusion.predict', + type='process', + config=config_resolved, + track_emissions=config['track_emissions'], + # Extra information was adjusted in 0.15.1 to ensure more relevant + # fit params are returned here. A script + # ~/code/geowatch/geowatch/cli/experimental/fixup_predict_kwcoco_metadata.py + # exist to help update old results to use this new format. 
+ extra={ + 'fit_config': fit_config + } + ) + # assert not list(util_json.find_json_unserializable(proc_context.obj)) + info.append(proc_context.obj) + proc_context.start() + test_coco_dataset = datamodule.coco_datasets['test'] + proc_context.add_disk_info(test_coco_dataset.fpath) + +memory_monitor_timer = ub.Timer().tic() +memory_monitor_interval_seconds = 60 * 60 +with_memory_units = bool(ub.modname_to_modpath('pint')) + +#-------------------------------------------------------------------------- +#--------------------------------------------------------------------------- +torch.set_grad_enabled(False) + +EMERGENCY_INPUT_AGREEMENT_HACK = 1 and hasattr(model, 'input_norms') + +# prog.set_extra(' <will populate stats after first video>') +# pman.start() + +prog = pman.progiter(batch_iter, desc='fusion predict') +_batch_iter = iter(prog) + +if 0: + item = test_dataloader.dataset[0] + + orig_batch = next(_batch_iter) + item = orig_batch[0] + item['target'] + frame = item['frames'][0] + ub.peek(frame['modes'].values()).shape + +batch_idx = 0 + +# can ignore pman +pman.stopall() + +item = test_dataloader.dataset[0] +orig_batch = next(_batch_iter) + +# check out orig_batch + +# Iterates through every item in dataset +# We just step through loop once +batch_idx += 1 +batch_trs = [] + +# Move data onto the prediction device, grab spacetime region info +fixed_batch = [] +for item in orig_batch: + if item is None: + continue + item = item.copy() + batch_gids = [frame['gid'] for frame in item['frames']] + frame_infos = [ub.udict(f) & { + 'gid', + 'output_space_slice', + 'output_image_dsize', + 'output_weights', + 'scale_outspace_from_vid', + } for f in item['frames']] + batch_trs.append({ + 'space_slice': tuple(item['target']['space_slice']), + # 'scale': item['target']['scale'], + 'scale': item['target'].get('scale', None), + 'gids': batch_gids, + 'frame_infos': frame_infos, + 'fliprot_params': item['target'].get('fliprot_params', None) + }) + position_tensors = item.get('positional_tensors', None) + if position_tensors is not None: + for k, v in position_tensors.items(): + position_tensors[k] = v.to(device) + + filtered_frames = [] + for frame in item['frames']: + frame = frame.copy() + sensor = frame['sensor'] + if EMERGENCY_INPUT_AGREEMENT_HACK: + try: + known_sensor_modes = model.input_norms[sensor] + except KeyError: + known_sensor_modes = None + continue + filtered_modes = {} + modes = frame['modes'] + for key, mode in modes.items(): + if EMERGENCY_INPUT_AGREEMENT_HACK: + if key not in known_sensor_modes: + continue + filtered_modes[key] = mode.to(device) + frame['modes'] = filtered_modes + filtered_frames.append(frame) + item['frames'] = filtered_frames + fixed_batch.append(item) + +# fixes batch components +batch = fixed_batch + +# can view again `batch` + + +from geowatch.utils.util_netharn import _debug_inbatch_shapes +print(_debug_inbatch_shapes(batch)) + + +if memory_monitor_timer.toc() > memory_monitor_interval_seconds: + # TODO: monitor memory usage and report if it looks like we + # are about to run out of memory, and maybe do something to + # handle it. + from geowatch.utils import util_hardware + mem_info = util_hardware.get_mem_info(with_units=with_memory_units) + print(f'\n\nmem_info = {ub.urepr(mem_info, nl=1)}\n\n') + memory_monitor_timer.tic() + + +# Entire purpose of the function. Where model connects to data and runs +# neural network. 
Arbritrary code we don't control +# Need to inject something to find hidden state +# Prepare input --> model.forward_step --> prepare output +outputs = model.forward_step(batch, with_loss=False) + +# view `outputs` +# view `outputs.keys()` + +# Compatibility step +outputs = {head_key_mapping.get(k, k): v for k, v in outputs.items()} +outputs.keys() + +# Checks/hack, runs once +got_outputs = list(outputs.keys()) +prog.ensure_newline() +writable_outputs = set(got_outputs) & expected_outputs +print('got_outputs = {!r}'.format(got_outputs)) +print('writable_outputs = {!r}'.format(writable_outputs)) + +# For each item in the batch, process the results +for head_key in writable_outputs: + head_probs = outputs[head_key] + head_stitcher = stitch_managers[head_key] + chan_keep_idxs = head_stitcher.head_keep_idxs + +# hack +predicted_frame_slice = slice(None) + +num_batches = len(batch_trs) + +# ----------------------------------------------------------------------- +for bx in range(num_batches): + target: dict = batch_trs[bx] + item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] + + # check the shape, was ([2,64,64,2]) for initial run + item_head_probs.shape + + # Keep only the channels we want to write to disk + # convert to numpy + item_head_relevant_probs = [p[..., chan_keep_idxs] for p in item_head_probs] + bin_probs = [p.detach().cpu().numpy() for p in item_head_relevant_probs] + + + + # check probs + DEBUG_PRED_SPATIAL_COVERAGE=0 + + frame_infos: list[dict] = target['frame_infos'][predicted_frame_slice] + + fliprot_params: dict = target['fliprot_params'] + # Update the stitcher with this windowed prediction + for probs, frame_info in zip(bin_probs, frame_infos): + if fliprot_params is not None: + # Undo fliprot TTA + probs = data_utils.inv_fliprot(probs, **fliprot_params) + + gid = frame_info['gid'] + output_image_dsize = frame_info['output_image_dsize'] + output_space_slice = frame_info['output_space_slice'] + scale_outspace_from_vid = frame_info['scale_outspace_from_vid'] + + if DEBUG_PRED_SPATIAL_COVERAGE: + image_id_to_video_space_slices[gid].append(target['space_slice']) + image_id_to_output_space_slices[gid].append(output_space_slice) + + output_weights = frame_info.get('output_weights', None) + + # View some details + print(probs.shape) + print(frame_info) + print(f"model._activation['hidden'] {model._activation_cache['hidden']}") + + + + # Checks if an image is done, submits to finalize + head_stitcher.accumulate_image( + gid, output_space_slice, probs, + asset_dsize=output_image_dsize, + scale_asset_from_stitchspace=scale_outspace_from_vid, + weights=output_weights, + downweight_edges=downweight_edges, + ) diff --git a/geowatch/tasks/fusion/predict.py b/geowatch/tasks/fusion/predict.py index 52c8dfe1c..6d0f545b1 100644 --- a/geowatch/tasks/fusion/predict.py +++ b/geowatch/tasks/fusion/predict.py @@ -201,6 +201,43 @@ class PredictConfig(DataModuleConfigMixin): is used. ''')) +# --------------Add hidden layer hook to model---------------- + +def _register_hidden_layer_hook(model): + # TODO: generalize to other models + # Specific to UNetR model + # These are at half of the output image resolution. 
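+    # The forward hook registered below stashes the chosen layer's output in
+    # model._activation_cache['hidden'] on every forward pass, so after
+    # model.forward_step(batch, with_loss=False) the prediction loop can read
+    # the hidden features back out (e.g. model._activation_cache['hidden'], as
+    # done in tutorial9_hidden_layer.py) without modifying the model itself.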
+ + model._activation_cache = {} + + # print("info on model", dir(model)) + + print("Enumerate over model.children()\n") + for i, layer in enumerate(model.children()): + print(f"Layer {i}: {layer}") + + # Not sure this is the correct code + encoder_layers = list(model.encoder.children()) + + print(f"\nNumber of encoder layers {len(encoder_layers)}\n") + + + print("Enumerate over model.encoder.children()\n") + for i, layer in enumerate(encoder_layers): + print(f"Encoder Layer {i}: {layer}") + + def record_hidden_activation(layer, input, output): + activation = output.detach() + model._activation_cache['hidden'] = activation + print(f"Hidden layer activation shape {activation.shape}") + + # Desired layer is the 5th layer of the 2nd encoder layer + layer_of_interest = list(model.encoder.children())[1][6] + layer_of_interest._forward_hooks.clear() + layer_of_interest.register_forward_hook(record_hidden_activation) +# ---------------------------------------------------------- + + # --------------Add hidden layer hook to model---------------- -- GitLab From 55e5c7c6dd730226cab961247f6984e71a295bcd Mon Sep 17 00:00:00 2001 From: Paul Beasly <paul.beasly@kitware.com> Date: Thu, 20 Jun 2024 12:10:27 -0400 Subject: [PATCH 2/3] Recent update of tutorial9 --- .../manual/tutorial/fusion_predict_notes.sh | 14 + .../tutorial/hidden_descriptors_notes.sh | 511 ++++++++++++++++++ .../manual/tutorial/hidden_layer_notes.sh | 43 ++ .../manual/tutorial/module_path_test.py | 6 + docs/source/manual/tutorial/record_demo.sh | 430 +++++++++++++++ .../manual/tutorial/tutorial9_hidden_layer.py | 49 +- geowatch/tasks/fusion/predict.py | 6 +- 7 files changed, 1046 insertions(+), 13 deletions(-) create mode 100644 docs/source/manual/tutorial/fusion_predict_notes.sh create mode 100644 docs/source/manual/tutorial/hidden_descriptors_notes.sh create mode 100644 docs/source/manual/tutorial/hidden_layer_notes.sh create mode 100644 docs/source/manual/tutorial/module_path_test.py create mode 100644 docs/source/manual/tutorial/record_demo.sh diff --git a/docs/source/manual/tutorial/fusion_predict_notes.sh b/docs/source/manual/tutorial/fusion_predict_notes.sh new file mode 100644 index 000000000..ac063c8f7 --- /dev/null +++ b/docs/source/manual/tutorial/fusion_predict_notes.sh @@ -0,0 +1,14 @@ +# notes on the structure and flow of the fusion predict model + +# Fusion predict creates predicted heatmaps +# Need to modify code to extract hidden layers + +# 1. Imports packages + +# 2. Define config classes + +# 2A. DataConfig `DataModuleConfigMixin()` class +# - does this need the `with_hidden` kwarg? +# 2B. PredictConfig, takes the previous config as arguments +# Noted as "Prediction script for the fusion task" +# - added `with_hidden_layers = scfg.Value('auto', help=None)` diff --git a/docs/source/manual/tutorial/hidden_descriptors_notes.sh b/docs/source/manual/tutorial/hidden_descriptors_notes.sh new file mode 100644 index 000000000..b7b14b4d3 --- /dev/null +++ b/docs/source/manual/tutorial/hidden_descriptors_notes.sh @@ -0,0 +1,511 @@ +__doc__=" + +Notes on exposing hidden layer descipritors in fusion/predict.py +================================================================ + +The script takes two inputs; the model and kwcoco data. Need to add a +stitching manager that gathers descritors/features from designated layer. Port +the paradigm from landcover/predict.py to fusion/predict.py. This will add +an option to output intermediate features from the model. 
+" + +# Overall steps for running the code: +# Setup uses the doctest info +# kwargs get overwritten by predict_kwargs +# Now the info is loaded into the namespace + +# the chan_code is where the features are saved + +#-----------------Basically----------------- +# 1. Get Activations from the model +# 2. Send to the stitching manager +#------------------------------------------------- + +# Next is to perform the prediction _run() function +# need to assign self to predictor for the class method +# script to modify +geowatch/tasks/fusion/predict.py + +# take the hidden layer hook from the script +geowatch/tasks/landcover/predict.py + +# Start with an interactive python sesstion +# Run code in doctests to setup paths and config. (approx line 1017-1069) +from geowatch.tasks.fusion.predict import * # NOQA (line 1017) +# ...to... +root_dpath = ub.Path(test_dpath, 'train').ensuredir() # line 1030 + +# test files go into cache directory specified by variable +test_dpath + +# run the code for config +fit_config = kwargs = {...} + +# the model being used is fit_lightning. To generate the model +fit_lightning.main(fit_config) + +# setup the predict_kwargs that overwrite the generic kwargs +# and perform the prediction (approx lines 1060-1069) + + +# Now we have the required info in the namespace. + +# Now move to the main script to run the predictor (line 1370) +# need to assign self to predictor for the class method + + +#-------------------------------------------- +# Notes on porting paradigm. The file is +tasks/landcover/predict.py +# line 227-240 finds the hidden layer +# lines 238-240 perform the register function +#----------------------------------------------------- + +# History of commands in interactive python session +# Need to run the ipython session in the geowatch directory + +# ----------------------------------------------------------------- +# List of what is imported in first line of code below +# {'utils', 'kwarray', 'datamodules', 'main', 'DataModuleConfigMixin', +# 'PredictConfig', 'kwimage', 'ub', 'Predictor', 'resolve_datamodule', +# 'profile', 'quantize_float01', 'before_import', 'build_stitching_managers', +# 'np', 'kwcoco', 'util_parallel', 'monkey_torch', 'monkey_torchmetrics', +# 'predict', 'CocoStitchingManager', 'scfg', 'torch', 'data_utils'} + + +#--------------------------------------------------------------------- +# From doctest (line 1016), import modules and set up some config args +from geowatch.tasks.fusion.predict import * # NOQA +import os +from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings +disable_lightning_hardware_warnings() +args = None +cmdline = False +devices = None +test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir() +results_path = (test_dpath / 'predict').ensuredir() +results_path.delete() +results_path.ensuredir() + +# generate kwcoco toydataset +import kwcoco +train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral') +test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral') + +root_dpath = ub.Path(test_dpath, 'train').ensuredir() + +# create config +fit_config = kwargs = { + 'subcommand': 'fit', + 'fit.data.train_dataset': train_dset.fpath, + 'fit.data.time_steps': 2, + 'fit.data.time_span': "2m", + 'fit.data.chip_dims': 64, + 'fit.data.time_sampling': 'hardish3', + 'fit.data.num_workers': 0, + # 'package_fpath': package_fpath, + 'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer', + 
'fit.model.init_args.global_change_weight': 1.0, + 'fit.model.init_args.global_class_weight': 1.0, + 'fit.model.init_args.global_saliency_weight': 1.0, + 'fit.optimizer.class_path': 'torch.optim.SGD', + 'fit.optimizer.init_args.lr': 1e-5, + 'fit.trainer.max_steps': 10, + 'fit.trainer.accelerator': 'cpu', + 'fit.trainer.devices': 1, + 'fit.trainer.max_epochs': 3, + 'fit.trainer.log_every_n_steps': 1, + 'fit.trainer.default_root_dir': os.fspath(root_dpath), +} + +from geowatch.tasks.fusion import fit_lightning +package_fpath = root_dpath / 'final_package.pt' + +# sets up the model +fit_lightning.main(fit_config) + +# Unfortunately, its not as easy to get the package path of +# this call.. +assert ub.Path(package_fpath).exists() +# Predict via that model +predict_kwargs = kwargs = { + 'package_fpath': package_fpath, + 'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json', + 'test_dataset': test_dset.fpath, + 'datamodule': 'KWCocoVideoDataModule', + 'batch_size': 1, + 'num_workers': 0, + 'devices': devices, + 'draw_batches': 1, +} + +# Does this run the prediction model? +result_dataset = predict(**kwargs) + +#----------------------------------------------------------------- +# Load config, model and dataset into predictor model +import rich +from rich.markup import escape +config = PredictConfig.cli(cmdline=cmdline, data=kwargs, strict=True) +rich.print('config = {}'.format(escape(ub.urepr(config, nl=2)))) +predictor = Predictor(config) +predictor._load_model() +predictor._load_dataset() + +#--------------------------------------------------------------- +# Move to the predictor _run() function (line 1370) +# Need to assign 'self' due to class method +self = predictor +datamodule = self.datamodule +model = self.model +config = self.config +test_coco_dataset = datamodule.coco_datasets['test'] + +# test_torch_dataset = datamodule.torch_datasets['test'] +# T, H, W = test_torch_dataset.window_dims + +# Create the results dataset as a copy of the test CocoDataset +print('Populate result dataset') +result_dataset: kwcoco.CocoDataset = test_coco_dataset.copy() + +# Remove all annotations in the results copy +if config['clear_annots']: + result_dataset.clear_annotations() + +# Change all paths to be absolute paths +result_dataset.reroot(absolute=True) +if not config['pred_dataset']: + raise ValueError( + f'Must specify path to the output (predicted) kwcoco file. 
' + f'Got {config["pred_dataset"]=}') +result_dataset.fpath = str(ub.Path(config['pred_dataset']).expand()) + +from geowatch.utils.lightning_ext import util_device +print('devices = {!r}'.format(config['devices'])) +devices = util_device.coerce_devices(config['devices']) +print('devices = {!r}'.format(devices)) +if len(devices) > 1: + raise NotImplementedError('TODO: handle multiple devices') +device = devices[0] + +fit_config = self.fit_config + +#--------------------------------------------------------------- +# The critical work happens here at _predict_critical_loop(), line 1407 +# The data module produces batches +# We move into the function to perform the operations (line 557) + +print('Predict on device = {!r}'.format(device)) +downweight_edges = config.downweight_edges + +UNPACKAGE_METHOD_HACK = 0 +if UNPACKAGE_METHOD_HACK: + # unpackage model hack + from geowatch.tasks.fusion import methods + unpackged_method = methods.MultimodalTransformer(**model.hparams) + unpackged_method.load_state_dict(model.state_dict()) + model = unpackged_method + +model = model.to(device) + +# Introspection of config +# Resolve what tasks are requested by looking at what heads are available. +global_head_weights = getattr(model, 'global_head_weights', {}) +if config['with_change'] == 'auto': + config['with_change'] = getattr(model, 'global_change_weight', 1.0) or global_head_weights.get('change', 1) +if config['with_class'] == 'auto': + config['with_class'] = getattr(model, 'global_class_weight', 1.0) or global_head_weights.get('class', 1) +if config['with_saliency'] == 'auto': + config['with_saliency'] = getattr(model, 'global_saliency_weight', 0.0) or global_head_weights.get('saliency', 1) + + +test_dataloader = datamodule.test_dataloader() +batch_iter = iter(test_dataloader) + +from kwutil import util_progress +pman = util_progress.ProgressManager(backend='rich') + +# prog = ub.ProgIter(batch_iter, desc='fusion predict', verbose=1, freq=1) + +# Make threads after starting background proces. 
+if config.write_workers == 'datamodule': + config.write_workers = datamodule.num_workers +writer_queue = util_parallel.BlockingJobQueue( + mode='thread', + # mode='serial', + max_workers=config.write_workers +) +result_fpath = ub.Path(result_dataset.fpath) +result_fpath.parent.ensuredir() +print('result_fpath = {!r}'.format(result_fpath)) + +# Definition of the building stitching managers starts line 198 +# Need a change here to include hidden layer +# 226-247 pass info +# Stitching manager needs bands and short code +# chan_code = is important +# create a stitching manager and add to the dicstionary +# Model knows what classes it wants to predict +# creates a 4-channel raster, then takes "saliency" +# create another stitching manager to gather descriptor features +# "coco-stitcher" + +#-------------------------------------------------------------- +stitch_managers = build_stitching_managers( + config, model, result_dataset, + writer_queue=writer_queue +) +stitch_managers + +#-------------------------------------------------------------- + +expected_outputs = set(stitch_managers.keys()) +got_outputs = None +writable_outputs = None + +print('Expected outputs: ' + str(expected_outputs)) + +head_key_mapping = { + 'saliency_probs': 'saliency', + 'class_probs': 'class', + 'change_probs': 'change', +} + +from geowatch.tasks.fusion.predict import _jsonify + +info = result_dataset.dataset.get('info', []) + +pred_dpath = ub.Path(result_dataset.fpath).parent +rich.print(f'Pred Dpath: [link={pred_dpath}]{pred_dpath}[/link]') + +DRAW_BATCHES = config.draw_batches +if DRAW_BATCHES: + viz_batch_dpath = (pred_dpath / '_viz_pred_batches').ensuredir() + +config_resolved = _jsonify(config.asdict()) +fit_config = _jsonify(fit_config) + +from kwcoco.util import util_json +unresolvable = list(util_json.find_json_unserializable(config_resolved)) +if unresolvable: + import warnings + warnings.warn(f'NotReproducibleWarning: Found unresolvable configuration options: {unresolvable!r}') + config_walker = ub.IndexableWalker(config_resolved) + for unresolvable_item in unresolvable: + _value = unresolvable_item['data'] + config_walker[unresolvable_item['loc']] = f'Unresolvable: {_value}' + + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + assert not unresolvable, 'should have entered dummy values for unresolvable data' + +if config['record_context']: + from geowatch.utils import process_context + proc_context = process_context.ProcessContext( + name='geowatch.tasks.fusion.predict', + type='process', + config=config_resolved, + track_emissions=config['track_emissions'], + # Extra information was adjusted in 0.15.1 to ensure more relevant + # fit params are returned here. A script + # ~/code/geowatch/geowatch/cli/experimental/fixup_predict_kwcoco_metadata.py + # exist to help update old results to use this new format. 
+ extra={ + 'fit_config': fit_config + } + ) + # assert not list(util_json.find_json_unserializable(proc_context.obj)) + info.append(proc_context.obj) + proc_context.start() + test_coco_dataset = datamodule.coco_datasets['test'] + proc_context.add_disk_info(test_coco_dataset.fpath) + +memory_monitor_timer = ub.Timer().tic() +memory_monitor_interval_seconds = 60 * 60 +with_memory_units = bool(ub.modname_to_modpath('pint')) + +#-------------------------------------------------------------------------- +#--------------------------------------------------------------------------- +torch.set_grad_enabled(False) + +EMERGENCY_INPUT_AGREEMENT_HACK = 1 and hasattr(model, 'input_norms') + +# prog.set_extra(' <will populate stats after first video>') +# pman.start() + +prog = pman.progiter(batch_iter, desc='fusion predict') +_batch_iter = iter(prog) + +if 0: + item = test_dataloader.dataset[0] + + orig_batch = next(_batch_iter) + item = orig_batch[0] + item['target'] + frame = item['frames'][0] + ub.peek(frame['modes'].values()).shape + +batch_idx = 0 + +# can ignore pman +pman.stopall() + +item = test_dataloader.dataset[0] +orig_batch = next(_batch_iter) + +# check out orig_batch + +# Iterates through every item in dataset +# We just step through loop once +batch_idx += 1 +batch_trs = [] + +# Move data onto the prediction device, grab spacetime region info +fixed_batch = [] +for item in orig_batch: + if item is None: + continue + item = item.copy() + batch_gids = [frame['gid'] for frame in item['frames']] + frame_infos = [ub.udict(f) & { + 'gid', + 'output_space_slice', + 'output_image_dsize', + 'output_weights', + 'scale_outspace_from_vid', + } for f in item['frames']] + batch_trs.append({ + 'space_slice': tuple(item['target']['space_slice']), + # 'scale': item['target']['scale'], + 'scale': item['target'].get('scale', None), + 'gids': batch_gids, + 'frame_infos': frame_infos, + 'fliprot_params': item['target'].get('fliprot_params', None) + }) + position_tensors = item.get('positional_tensors', None) + if position_tensors is not None: + for k, v in position_tensors.items(): + position_tensors[k] = v.to(device) + + filtered_frames = [] + for frame in item['frames']: + frame = frame.copy() + sensor = frame['sensor'] + if EMERGENCY_INPUT_AGREEMENT_HACK: + try: + known_sensor_modes = model.input_norms[sensor] + except KeyError: + known_sensor_modes = None + continue + filtered_modes = {} + modes = frame['modes'] + for key, mode in modes.items(): + if EMERGENCY_INPUT_AGREEMENT_HACK: + if key not in known_sensor_modes: + continue + filtered_modes[key] = mode.to(device) + frame['modes'] = filtered_modes + filtered_frames.append(frame) + item['frames'] = filtered_frames + fixed_batch.append(item) + +# fixes batch components +batch = fixed_batch + +# can view again `batch` + + +from geowatch.utils.util_netharn import _debug_inbatch_shapes +print(_debug_inbatch_shapes(batch)) + + +if memory_monitor_timer.toc() > memory_monitor_interval_seconds: + # TODO: monitor memory usage and report if it looks like we + # are about to run out of memory, and maybe do something to + # handle it. + from geowatch.utils import util_hardware + mem_info = util_hardware.get_mem_info(with_units=with_memory_units) + print(f'\n\nmem_info = {ub.urepr(mem_info, nl=1)}\n\n') + memory_monitor_timer.tic() + + +# Entire purpose of the function. Where model connects to data and runs +# neural network. 
Arbritrary code we don't control +# Need to inject something to find hidden state +# Prepare input --> model.forward_step --> prepare output +outputs = model.forward_step(batch, with_loss=False) + +# view `outputs` +# view `outputs.keys()` + +# Compatibility step +outputs = {head_key_mapping.get(k, k): v for k, v in outputs.items()} +outputs.keys() + +# Checks/hack, runs once +got_outputs = list(outputs.keys()) +prog.ensure_newline() +writable_outputs = set(got_outputs) & expected_outputs +print('got_outputs = {!r}'.format(got_outputs)) +print('writable_outputs = {!r}'.format(writable_outputs)) + +# For each item in the batch, process the results +for head_key in writable_outputs: + head_probs = outputs[head_key] + head_stitcher = stitch_managers[head_key] + chan_keep_idxs = head_stitcher.head_keep_idxs + +# hack +predicted_frame_slice = slice(None) + +num_batches = len(batch_trs) + +# ----------------------------------------------------------------------- +for bx in range(num_batches): + target: dict = batch_trs[bx] + item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] + + # check the shape, was ([2,64,64,2]) for initial run + item_head_probs.shape + + # Keep only the channels we want to write to disk + # convert to numpy + item_head_relevant_probs = [p[..., chan_keep_idxs] for p in item_head_probs] + bin_probs = [p.detach().cpu().numpy() for p in item_head_relevant_probs] + + + + # check probs + DEBUG_PRED_SPATIAL_COVERAGE=0 + + frame_infos: list[dict] = target['frame_infos'][predicted_frame_slice] + + fliprot_params: dict = target['fliprot_params'] + # Update the stitcher with this windowed prediction + for probs, frame_info in zip(bin_probs, frame_infos): + if fliprot_params is not None: + # Undo fliprot TTA + probs = data_utils.inv_fliprot(probs, **fliprot_params) + + gid = frame_info['gid'] + output_image_dsize = frame_info['output_image_dsize'] + output_space_slice = frame_info['output_space_slice'] + scale_outspace_from_vid = frame_info['scale_outspace_from_vid'] + + if DEBUG_PRED_SPATIAL_COVERAGE: + image_id_to_video_space_slices[gid].append(target['space_slice']) + image_id_to_output_space_slices[gid].append(output_space_slice) + + output_weights = frame_info.get('output_weights', None) + + # View some details + probs.shape + frame_info + + # Checks if an image is done, submits to finalize + head_stitcher.accumulate_image( + gid, output_space_slice, probs, + asset_dsize=output_image_dsize, + scale_asset_from_stitchspace=scale_outspace_from_vid, + weights=output_weights, + downweight_edges=downweight_edges, + ) diff --git a/docs/source/manual/tutorial/hidden_layer_notes.sh b/docs/source/manual/tutorial/hidden_layer_notes.sh new file mode 100644 index 000000000..0f645181c --- /dev/null +++ b/docs/source/manual/tutorial/hidden_layer_notes.sh @@ -0,0 +1,43 @@ +# Document the flow and to output hidden layers in landcover/predict + +# The code uses the model (DZYNE_LANDCOVER_MODEL) and drop6 data. +# Model is a based on pytorch for landcover classification. + +# Code requires some basic kwargs and runs the `predict` function. + +# 1. imports packages +# 2. Loads config from the `LandcoverPredictConfig` class and adds cli kwargs +# - this config class has `with_hidden` kwarg + +# 3. Unpack variables from the config, dataset and model +# - `model` is loaded by runnning + model_info = lookup_model_info(weights_filename) + ptdataset = model_info.create_dataset(input_dset) + model = model_info.load_model(weights_filename, device) + +# 4. 
Run the `CocoStitchingManager` to get the stitchers. + +# 5. For the `hidden_stitcher` option: +# - Runs the `_register_hidden_layer_hook` on with the loaded model +# - Adds `._activation_cache = {}` dict attribute to the model +# - Adds `layer_of_interest` from `model.decoder1[]` attribute +# - Clears hooks and `register_forward_hook()` using +# - `record_hidden_activation()` that detaches output and stores in +model._activation_cache['hidden'] = activation + + +# 6. Starts the main predict loop +# - Starts the `ProcessContext()` function +# - starts the `pman` progress manager +# - sets up the wrapping for the main loop `for img_info in _prog` +# - predicts current image with `_predict_single` +# - uses both stitchers + +# 7. Steps into `def _predict_single` function +# - There is a hidden_scale - need to figure what this is +# - Runs a sliding window over the image + +# 8. Extracts the hidden layer using +hidden_raw = model._activation_cache['hidden'].cpu().numpy() +# - Then hardcodes transformations of the activation layer +# - back to the image stitcher diff --git a/docs/source/manual/tutorial/module_path_test.py b/docs/source/manual/tutorial/module_path_test.py new file mode 100644 index 000000000..dc2b521e2 --- /dev/null +++ b/docs/source/manual/tutorial/module_path_test.py @@ -0,0 +1,6 @@ +# simple test to see if modeules are loaded + + + +from geowatch.tasks.fusion.predict import * +print(dir(Predictor)) diff --git a/docs/source/manual/tutorial/record_demo.sh b/docs/source/manual/tutorial/record_demo.sh new file mode 100644 index 000000000..42648c231 --- /dev/null +++ b/docs/source/manual/tutorial/record_demo.sh @@ -0,0 +1,430 @@ +__doc__=" + +Notes on exposing hidden layer descipritors in fusion/predict.py +================================================================ + +The script takes two inputs; the model and kwcoco data. Need to add a +stitching manager that gathers descritors/features from designated layer. Port +the paradigm from landcover/predict.py to fusion/predict.py. This will add +an option to output intermediate features from the model. +" + +# Overall steps for running the code: +# Setup uses the doctest info +# kwargs get overwritten by predict_kwargs +# Now the info is loaded into the namespace + +# Next is to perform the prediction _run() function +# need to assign self to predictor for the class method +# script to modify +geowatch/tasks/fusion/predict.py + +# take the hidden layer hook from the script +geowatch/tasks/landcover/predict.py + +# Start with an interactive pyton sesstion +# Run code in doctests to setup paths and config. (approx line 1017-1069) +from geowatch.tasks.fusion.predict import * # NOQA (line 1017) +# ...to... +root_dpath = ub.Path(test_dpath, 'train').ensuredir() # line 1030 + +# test files go into cache directory specified by variable +test_dpath + +# run the code for config +fit_config = kwargs = {...} + +# the model being used is fit_lightning. To generate the model +fit_lightning.main(fit_config) + +# setup the predict_kwargs that overwrite the generic kwargs +# and perform the prediction (approx lines 1060-1069) + + +# Now we have the required info in the namespace. 
+ +# Now move to the main script to run the predictor (line 1370) +# need to assign self to predictor for the class method + + +# History of commands in interactive python session + +# From doctest, import modules and set up some config args +from geowatch.tasks.fusion.predict import * # NOQA +import os +from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings +disable_lightning_hardware_warnings() +args = None +cmdline = False +devices = None +test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir() + +test_dpath # view path + +results_path = (test_dpath / 'predict').ensuredir() +results_path.delete() +results_path.ensuredir() + +# generate kwcoco toydataset +import kwcoco +train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral') +test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral') +20/4: + >>> root_dpath = ub.Path(test_dpath, 'train').ensuredir() + >>> fit_config = kwargs = { + ... 'subcommand': 'fit', + ... 'fit.data.train_dataset': train_dset.fpath, + ... 'fit.data.time_steps': 2, + ... 'fit.data.time_span': "2m", + ... 'fit.data.chip_dims': 64, + ... 'fit.data.time_sampling': 'hardish3', + ... 'fit.data.num_workers': 0, + ... #'package_fpath': package_fpath, + ... 'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer', + ... 'fit.model.init_args.global_change_weight': 1.0, + ... 'fit.model.init_args.global_class_weight': 1.0, + ... 'fit.model.init_args.global_saliency_weight': 1.0, + ... 'fit.optimizer.class_path': 'torch.optim.SGD', + ... 'fit.optimizer.init_args.lr': 1e-5, + ... 'fit.trainer.max_steps': 10, + ... 'fit.trainer.accelerator': 'cpu', + ... 'fit.trainer.devices': 1, + ... 'fit.trainer.max_epochs': 3, + ... 'fit.trainer.log_every_n_steps': 1, + ... 'fit.trainer.default_root_dir': os.fspath(root_dpath), + ... } + >>> from geowatch.tasks.fusion import fit_lightning + >>> package_fpath = root_dpath / 'final_package.pt' +20/5: >>> fit_lightning.main(fit_config) +20/6: + >>> # Unfortunately, its not as easy to get the package path of + >>> # this call.. 
+ >>> assert ub.Path(package_fpath).exists() + >>> # Predict via that model + >>> predict_kwargs = kwargs = { + >>> 'package_fpath': package_fpath, + >>> 'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json', + >>> 'test_dataset': test_dset.fpath, + >>> 'datamodule': 'KWCocoVideoDataModule', + >>> 'batch_size': 1, + >>> 'num_workers': 0, + >>> 'devices': devices, + >>> 'draw_batches': 1, + >>> } +20/7: >>> result_dataset = predict(**kwargs) +20/8: + import rich + from rich.markup import escape + config = PredictConfig.cli(cmdline=cmdline, data=kwargs, strict=True) + rich.print('config = {}'.format(escape(ub.urepr(config, nl=2)))) +20/9: predictor = Predictor(config) +20/10: + predictor._load_model() + predictor._load_dataset() +20/11: self = predictor +20/12: + datamodule = self.datamodule + model = self.model + config = self.config + + test_coco_dataset = datamodule.coco_datasets['test'] + + # test_torch_dataset = datamodule.torch_datasets['test'] + # T, H, W = test_torch_dataset.window_dims + + # Create the results dataset as a copy of the test CocoDataset + print('Populate result dataset') + result_dataset: kwcoco.CocoDataset = test_coco_dataset.copy() + + # Remove all annotations in the results copy + if config['clear_annots']: + result_dataset.clear_annotations() + + # Change all paths to be absolute paths + result_dataset.reroot(absolute=True) + if not config['pred_dataset']: + raise ValueError( + f'Must specify path to the output (predicted) kwcoco file. ' + f'Got {config["pred_dataset"]=}') + result_dataset.fpath = str(ub.Path(config['pred_dataset']).expand()) + + from geowatch.utils.lightning_ext import util_device + print('devices = {!r}'.format(config['devices'])) + devices = util_device.coerce_devices(config['devices']) + print('devices = {!r}'.format(devices)) + if len(devices) > 1: + raise NotImplementedError('TODO: handle multiple devices') + device = devices[0] + + fit_config = self.fit_config +20/13: + import rich + + print('Predict on device = {!r}'.format(device)) + downweight_edges = config.downweight_edges + + UNPACKAGE_METHOD_HACK = 0 + if UNPACKAGE_METHOD_HACK: + # unpackage model hack + from geowatch.tasks.fusion import methods + unpackged_method = methods.MultimodalTransformer(**model.hparams) + unpackged_method.load_state_dict(model.state_dict()) + model = unpackged_method + + model = model.to(device) + global_head_weights = getattr(model, 'global_head_weights', {}) + if config['with_change'] == 'auto': + config['with_change'] = getattr(model, 'global_change_weight', 1.0) or global_head_weights.get('change', 1) + if config['with_class'] == 'auto': + config['with_class'] = getattr(model, 'global_class_weight', 1.0) or global_head_weights.get('class', 1) + if config['with_saliency'] == 'auto': + config['with_saliency'] = getattr(model, 'global_saliency_weight', 0.0) or global_head_weights.get('saliency', 1) +20/14: + test_dataloader = datamodule.test_dataloader() + batch_iter = iter(test_dataloader) + + from kwutil import util_progress + pman = util_progress.ProgressManager(backend='rich') + + # prog = ub.ProgIter(batch_iter, desc='fusion predict', verbose=1, freq=1) + + # Make threads after starting background proces. 
+ if config.write_workers == 'datamodule': + config.write_workers = datamodule.num_workers + writer_queue = util_parallel.BlockingJobQueue( + mode='thread', + # mode='serial', + max_workers=config.write_workers + ) +20/15: + result_fpath = ub.Path(result_dataset.fpath) + result_fpath.parent.ensuredir() + print('result_fpath = {!r}'.format(result_fpath)) + + stitch_managers = build_stitching_managers( + config, model, result_dataset, + writer_queue=writer_queue + ) +20/16: stitch_managers +20/17: # error with indentation, move to 20/18 +20/18: + expected_outputs = set(stitch_managers.keys()) + got_outputs = None + writable_outputs = None + + print('Expected outputs: ' + str(expected_outputs)) + + head_key_mapping = { + 'saliency_probs': 'saliency', + 'class_probs': 'class', + 'change_probs': 'change', + } +20/19: # error, need to import _jsonify. Jump to 20/21 + +20/21: from geowatch.tasks.fusion.predict import _jsonify +20/22: + info = result_dataset.dataset.get('info', []) + + pred_dpath = ub.Path(result_dataset.fpath).parent + rich.print(f'Pred Dpath: [link={pred_dpath}]{pred_dpath}[/link]') + + DRAW_BATCHES = config.draw_batches + if DRAW_BATCHES: + viz_batch_dpath = (pred_dpath / '_viz_pred_batches').ensuredir() + + config_resolved = _jsonify(config.asdict()) + fit_config = _jsonify(fit_config) + + from kwcoco.util import util_json + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + if unresolvable: + import warnings + warnings.warn(f'NotReproducibleWarning: Found unresolvable configuration options: {unresolvable!r}') + config_walker = ub.IndexableWalker(config_resolved) + for unresolvable_item in unresolvable: + _value = unresolvable_item['data'] + config_walker[unresolvable_item['loc']] = f'Unresolvable: {_value}' + + unresolvable = list(util_json.find_json_unserializable(config_resolved)) + assert not unresolvable, 'should have entered dummy values for unresolvable data' +20/23: + if config['record_context']: + from geowatch.utils import process_context + proc_context = process_context.ProcessContext( + name='geowatch.tasks.fusion.predict', + type='process', + config=config_resolved, + track_emissions=config['track_emissions'], + # Extra information was adjusted in 0.15.1 to ensure more relevant + # fit params are returned here. A script + # ~/code/geowatch/geowatch/cli/experimental/fixup_predict_kwcoco_metadata.py + # exist to help update old results to use this new format. 
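            # ProcessContext is the provenance record for this run: it
            # captures the resolved prediction config, the fit-time
            # hyperparameters passed through `extra` below, and optional
            # emissions tracking; disk info for the test dataset is added to
            # it right after it starts.  Its `obj` dict is appended to the
            # kwcoco `info` section so the prediction file stays traceable to
            # the model and data that produced it.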
+ extra={ + 'fit_config': fit_config + } + ) + # assert not list(util_json.find_json_unserializable(proc_context.obj)) + info.append(proc_context.obj) + proc_context.start() + test_coco_dataset = datamodule.coco_datasets['test'] + proc_context.add_disk_info(test_coco_dataset.fpath) + + memory_monitor_timer = ub.Timer().tic() + memory_monitor_interval_seconds = 60 * 60 + with_memory_units = bool(ub.modname_to_modpath('pint')) +20/24: torch.set_grad_enabled(False) +20/25: + EMERGENCY_INPUT_AGREEMENT_HACK = 1 and hasattr(model, 'input_norms') + + # prog.set_extra(' <will populate stats after first video>') + # pman.start() + + prog = pman.progiter(batch_iter, desc='fusion predict') + _batch_iter = iter(prog) + if 0: + item = test_dataloader.dataset[0] + + orig_batch = next(_batch_iter) + item = orig_batch[0] + item['target'] + frame = item['frames'][0] + ub.peek(frame['modes'].values()).shape + + batch_idx = 0 +20/26: pman.stopall() +20/27: item = test_dataloader.dataset[0] +20/28: orig_batch = next(_batch_iter) +20/29: orig_batch +20/30: + batch_idx += 1 + batch_trs = [] + # Move data onto the prediction device, grab spacetime region info + fixed_batch = [] + for item in orig_batch: + if item is None: + continue + item = item.copy() + batch_gids = [frame['gid'] for frame in item['frames']] + frame_infos = [ub.udict(f) & { + 'gid', + 'output_space_slice', + 'output_image_dsize', + 'output_weights', + 'scale_outspace_from_vid', + } for f in item['frames']] + batch_trs.append({ + 'space_slice': tuple(item['target']['space_slice']), + # 'scale': item['target']['scale'], + 'scale': item['target'].get('scale', None), + 'gids': batch_gids, + 'frame_infos': frame_infos, + 'fliprot_params': item['target'].get('fliprot_params', None) + }) + position_tensors = item.get('positional_tensors', None) + if position_tensors is not None: + for k, v in position_tensors.items(): + position_tensors[k] = v.to(device) + + filtered_frames = [] + for frame in item['frames']: + frame = frame.copy() + sensor = frame['sensor'] + if EMERGENCY_INPUT_AGREEMENT_HACK: + try: + known_sensor_modes = model.input_norms[sensor] + except KeyError: + known_sensor_modes = None + continue + filtered_modes = {} + modes = frame['modes'] + for key, mode in modes.items(): + if EMERGENCY_INPUT_AGREEMENT_HACK: + if key not in known_sensor_modes: + continue + filtered_modes[key] = mode.to(device) + frame['modes'] = filtered_modes + filtered_frames.append(frame) + item['frames'] = filtered_frames + fixed_batch.append(item) +20/31: batch = fixed_batch +20/32: batch +20/33: + from geowatch.utils.util_netharn import _debug_inbatch_shapes + print(_debug_inbatch_shapes(batch)) +20/34: + if memory_monitor_timer.toc() > memory_monitor_interval_seconds: + # TODO: monitor memory usage and report if it looks like we + # are about to run out of memory, and maybe do something to + # handle it. 
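    # This check fires at most once per `memory_monitor_interval_seconds`
    # (an hour here): it queries host memory through
    # util_hardware.get_mem_info, prints the snapshot, and re-arms the timer
    # with `.tic()`.  `with_units` is only True when the optional `pint`
    # package is importable, which lets the report use human-readable units.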
+ from geowatch.utils import util_hardware + mem_info = util_hardware.get_mem_info(with_units=with_memory_units) + print(f'\n\nmem_info = {ub.urepr(mem_info, nl=1)}\n\n') + memory_monitor_timer.tic() +20/35: mem_info +20/36: outputs = model.forward_step(batch, with_loss=False) +20/37: outputs +20/38: # error +20/39: outputs.keys() +20/40: outputs = {head_key_mapping.get(k, k): v for k, v in outputs.items()} +20/41: outputs.keys() +20/42: + got_outputs = list(outputs.keys()) + prog.ensure_newline() + writable_outputs = set(got_outputs) & expected_outputs + print('got_outputs = {!r}'.format(got_outputs)) + print('writable_outputs = {!r}'.format(writable_outputs)) +20/43: + for head_key in writable_outputs: + head_probs = outputs[head_key] +20/44: head_key +20/45: head_probs +20/46-48: # errors +20/49: + head_stitcher = stitch_managers[head_key] + chan_keep_idxs = head_stitcher.head_keep_idxs +20/50: chan_keep_idxs +20/51: predicted_frame_slice = slice(None) +20/52: + num_batches = len(batch_trs) + + for bx in range(num_batches): + target: dict = batch_trs[bx] + item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] +20/53: item_head_probs +20/54: item_head_probs.shape +20/55: + item_head_relevant_probs = [p[..., chan_keep_idxs] for p in item_head_probs] + bin_probs = [p.detach().cpu().numpy() for p in item_head_relevant_probs] +20/56: # error +20/57: DEBUG_PRED_SPATIAL_COVERAGE=0 +20/58: + frame_infos: list[dict] = target['frame_infos'][predicted_frame_slice] + + fliprot_params: dict = target['fliprot_params'] + # Update the stitcher with this windowed prediction + for probs, frame_info in zip(bin_probs, frame_infos): + if fliprot_params is not None: + # Undo fliprot TTA + probs = data_utils.inv_fliprot(probs, **fliprot_params) + + gid = frame_info['gid'] + output_image_dsize = frame_info['output_image_dsize'] + output_space_slice = frame_info['output_space_slice'] + scale_outspace_from_vid = frame_info['scale_outspace_from_vid'] + + if DEBUG_PRED_SPATIAL_COVERAGE: + image_id_to_video_space_slices[gid].append(target['space_slice']) + image_id_to_output_space_slices[gid].append(output_space_slice) + + output_weights = frame_info.get('output_weights', None) +20/59: probs +20/60: probs.shape +20/61: frame_info +20/62: + head_stitcher.accumulate_image( + gid, output_space_slice, probs, + asset_dsize=output_image_dsize, + scale_asset_from_stitchspace=scale_outspace_from_vid, + weights=output_weights, + downweight_edges=downweight_edges, + ) diff --git a/docs/source/manual/tutorial/tutorial9_hidden_layer.py b/docs/source/manual/tutorial/tutorial9_hidden_layer.py index 96b8b90e2..8d57c8292 100644 --- a/docs/source/manual/tutorial/tutorial9_hidden_layer.py +++ b/docs/source/manual/tutorial/tutorial9_hidden_layer.py @@ -1,12 +1,14 @@ -# Tutorial to extract hidden layer features from a pre-trained model +# Tutorial code to extract hidden layer features from a pre-trained model +# Need to then pass it to the stitching manager import sys import os import ubelt as ub - -# import module +# -------------------------------------------------------------- +# import module and details from the doctest example from geowatch.tasks.fusion.predict import * # NOQA +from geowatch.tasks.fusion.predict import _register_hidden_layer_hook print(Predictor) from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings disable_lightning_hardware_warnings() @@ -52,11 +54,12 @@ fit_config = kwargs = { from geowatch.tasks.fusion import fit_lightning package_fpath = root_dpath / 
'final_package.pt' -# sets up the model +# sets up the training model and trains it fit_lightning.main(fit_config) # Unfortunately, its not as easy to get the package path of # this call.. +# Looks like it setups up kwargs for the prediction model assert ub.Path(package_fpath).exists() # Predict via that model predict_kwargs = kwargs = { @@ -69,9 +72,12 @@ predict_kwargs = kwargs = { 'devices': devices, 'draw_batches': 1, } - -# Does this run the prediction model? -result_dataset = predict(**kwargs) +#--------------------------------- +# Performs a prediction on the dataset +# this runs parts of the code shown below. Script runs the various portions +#------------------------------------- +## result_dataset = predict(**kwargs) +#--------------------------------------- #----------------------------------------------------------------- # Load config, model and dataset into predictor model @@ -174,8 +180,8 @@ print('result_fpath = {!r}'.format(result_fpath)) # Need a change here to include hidden layer # 226-247 pass info # Stitching manager needs bands and short code -# chan_code = is important -# create a stitching manager and add to the dicstionary +# chan_code = is important, provides the path to the features/activations +# create a stitching manager and add to the dictionary # Model knows what classes it wants to predict # creates a 4-channel raster, then takes "saliency" # create another stitching manager to gather descriptor features @@ -366,6 +372,23 @@ if memory_monitor_timer.toc() > memory_monitor_interval_seconds: # Prepare input --> model.forward_step --> prepare output outputs = model.forward_step(batch, with_loss=False) + +# Hidden_layers have been extracted at this point +print(model._activation_cache.keys()) +print(model._activation_cache['hidden']) +print(model._activation_cache['hidden'].shape) + +#----this needs to get added to the source code + +hidden_stitcher = CocoStitchingManager( + result_dataset + 'hidden_layers', + chan_code='hidden_layers', + stiching_space='image', + writer_queue=writer_queue, + assets_dname=config.assets_dname, +) + # view `outputs` # view `outputs.keys()` @@ -376,7 +399,8 @@ outputs.keys() # Checks/hack, runs once got_outputs = list(outputs.keys()) prog.ensure_newline() -writable_outputs = set(got_outputs) & expected_outputs +# This code eliminates the "hidden_layers" key +writable_outputs = expected_outputs & set(got_outputs) print('got_outputs = {!r}'.format(got_outputs)) print('writable_outputs = {!r}'.format(writable_outputs)) @@ -392,6 +416,8 @@ predicted_frame_slice = slice(None) num_batches = len(batch_trs) # ----------------------------------------------------------------------- +# This code looks like it does some cleanup, filtering and then +# writes the results. 
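# Note: the loop below looks up stitch_managers['hidden_layers'], which will
# raise a KeyError unless a stitcher was registered under that name (the
# construction near the forward_step call above is commented out).  A minimal
# sketch of such a registration, mirroring that commented-out block with the
# missing comma between the first two arguments added; the exact
# CocoStitchingManager signature (including the `stiching_space` spelling) is
# taken from that block and should be checked against predict.py.
if 'hidden_layers' not in stitch_managers:
    stitch_managers['hidden_layers'] = CocoStitchingManager(
        result_dataset,
        'hidden_layers',
        chan_code='hidden_layers',
        stiching_space='image',
        writer_queue=writer_queue,
        assets_dname=config.assets_dname,
    )
# For the stock predict loop to pick this head up automatically, the
# registration would have to happen before `expected_outputs` and
# `writable_outputs` are computed, since that intersection is what drops
# the 'hidden_layers' key.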
for bx in range(num_batches): target: dict = batch_trs[bx] item_head_probs: list[torch.Tensor] | torch.Tensor = head_probs[bx] @@ -433,8 +459,9 @@ for bx in range(num_batches): print(probs.shape) print(frame_info) print(f"model._activation['hidden'] {model._activation_cache['hidden']}") + print(f"\n hidden layer shape {model._activation_cache['hidden'].shape}") - + head_stitcher = stitch_managers["hidden_layers"] # Checks if an image is done, submits to finalize head_stitcher.accumulate_image( diff --git a/geowatch/tasks/fusion/predict.py b/geowatch/tasks/fusion/predict.py index 6d0f545b1..06aeb1172 100644 --- a/geowatch/tasks/fusion/predict.py +++ b/geowatch/tasks/fusion/predict.py @@ -229,14 +229,16 @@ def _register_hidden_layer_hook(model): def record_hidden_activation(layer, input, output): activation = output.detach() model._activation_cache['hidden'] = activation - print(f"Hidden layer activation shape {activation.shape}") + print(f"Hidden Layer Extracted! Shape is {activation.shape}") # Desired layer is the 5th layer of the 2nd encoder layer + # See `/docs/source/manual/tutorial/fusion_model_layer_info.sh` + # for an example structure of the model layer_of_interest = list(model.encoder.children())[1][6] layer_of_interest._forward_hooks.clear() layer_of_interest.register_forward_hook(record_hidden_activation) -# ---------------------------------------------------------- +# ---------------------------------------------------------- # --------------Add hidden layer hook to model---------------- -- GitLab From 7e469cba48922ce2a986e576cf08fcd4448d00a4 Mon Sep 17 00:00:00 2001 From: joncrall <jon.crall@kitware.com> Date: Thu, 20 Jun 2024 13:11:24 -0400 Subject: [PATCH 3/3] Hack outputs to contain hidden features --- .../manual/tutorial/tutorial9_hidden_layer.py | 20 ++++++------- geowatch/tasks/fusion/predict.py | 29 +++++++++---------- geowatch_tpl/submodules/loss-of-plasticity | 2 +- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/docs/source/manual/tutorial/tutorial9_hidden_layer.py b/docs/source/manual/tutorial/tutorial9_hidden_layer.py index 8d57c8292..e1db42bf2 100644 --- a/docs/source/manual/tutorial/tutorial9_hidden_layer.py +++ b/docs/source/manual/tutorial/tutorial9_hidden_layer.py @@ -378,16 +378,16 @@ print(model._activation_cache.keys()) print(model._activation_cache['hidden']) print(model._activation_cache['hidden'].shape) -#----this needs to get added to the source code - -hidden_stitcher = CocoStitchingManager( - result_dataset - 'hidden_layers', - chan_code='hidden_layers', - stiching_space='image', - writer_queue=writer_queue, - assets_dname=config.assets_dname, -) +# #----this needs to get added to the source code + +# hidden_stitcher = CocoStitchingManager( +# result_dataset +# 'hidden_layers', +# chan_code='hidden_layers', +# stiching_space='image', +# writer_queue=writer_queue, +# assets_dname=config.assets_dname, +# ) # view `outputs` # view `outputs.keys()` diff --git a/geowatch/tasks/fusion/predict.py b/geowatch/tasks/fusion/predict.py index 06aeb1172..87d3db976 100644 --- a/geowatch/tasks/fusion/predict.py +++ b/geowatch/tasks/fusion/predict.py @@ -212,29 +212,28 @@ def _register_hidden_layer_hook(model): # print("info on model", dir(model)) - print("Enumerate over model.children()\n") - for i, layer in enumerate(model.children()): - print(f"Layer {i}: {layer}") + # print("Enumerate over model.children()\n") + # for i, layer in enumerate(model.children()): + # print(f"Layer {i}: {layer}") # Not sure this is the correct code - encoder_layers = 
list(model.encoder.children())
-
-    print(f"\nNumber of encoder layers {len(encoder_layers)}\n")
+    # Hack to grab the inputs to one of the heads
+    # This will let us grab pre-formatted spacetime features
+    # out of the multimodal model.
+    available_decoders = (ub.oset(['saliency', 'class', 'change']) & model.heads.keys())
+    chosen_head_key = available_decoders[0]
+    layer_of_interest = model.heads[chosen_head_key].hidden.hidden0.conv
 
-    print("Enumerate over model.encoder.children()\n")
-    for i, layer in enumerate(encoder_layers):
-        print(f"Encoder Layer {i}: {layer}")
-
-    def record_hidden_activation(layer, input, output):
-        activation = output.detach()
+    def record_hidden_activation(layer, inputs, output):
+        assert len(inputs) == 1
+        input_features = inputs[0]
+        activation = input_features.detach()
         model._activation_cache['hidden'] = activation
         print(f"Hidden Layer Extracted! Shape is {activation.shape}")
 
-    # Desired layer is the 5th layer of the 2nd encoder layer
     # See `/docs/source/manual/tutorial/fusion_model_layer_info.sh`
     # for an example structure of the model
-    layer_of_interest = list(model.encoder.children())[1][6]
     layer_of_interest._forward_hooks.clear()
     layer_of_interest.register_forward_hook(record_hidden_activation)
 
@@ -549,7 +548,7 @@ def resolve_datamodule(config, model, datamodule_defaults, fit_config):
     DZYNE_MODEL_HACK = 1
     if DZYNE_MODEL_HACK and isinstance(config['package_fpath'], str):
         package_fpath = ub.Path(config['package_fpath'])
         if package_fpath.stem == 'lc_rgb_fusion_model_package':
             # This model has an issue with the L8 features it was trained on
             datamodule_vars['exclude_sensors'] = ['L8']
 
diff --git a/geowatch_tpl/submodules/loss-of-plasticity b/geowatch_tpl/submodules/loss-of-plasticity
index eb0868220..50f7a5614 160000
--- a/geowatch_tpl/submodules/loss-of-plasticity
+++ b/geowatch_tpl/submodules/loss-of-plasticity
@@ -1 +1 @@
-Subproject commit eb0868220cebb4930476ff51110cffddfb983b0a
+Subproject commit 50f7a561445e734c7444dc4f354d9b22b2644cb1
-- 
GitLab