Optimization: Faster TimeWindowSampler
Currently the TimeWindowSampler does a fancy sampling procedure based on a "time kernel" and affinity-matrix sampling (see the code for details).
One of the doctests illustrates the process:
xdoctest -m geowatch.tasks.fusion.datamodules.temporal_sampling __doc__:0 --show
The left image shows the affinity matrix, which represents how much each frame wants to be sampled with every other frame. It also shows a nondeterministic and a deterministic set of samples drawn with the sampling process that uses this affinity matrix.
A query consists of a time index of interest and a desired number of frames. The sampler is then asked to "fill in" the rest of the frames.
The right side shows how this sampling process works for a single query. It takes the row of the affinity matrix corresponding to the time index of interest, which is the initial probability distribution for sampling all other frames. This distribution is modulated by a few factors depending on the parameters (e.g. the time kernel). Then a sample is drawn, and that becomes the "next" sample. Its row of the affinity matrix is "folded into" the distribution, and the other modulations are applied (time kernel, downweighting previously chosen indexes). This process repeats until the desired number of frames has been sampled or a sampling error occurs (e.g. there are no more frames to sample from).
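For reference, here is a minimal sketch of that iterative query, useful as a baseline when comparing alternatives. It is not the geowatch implementation. `affinity_sample` is a hypothetical name. It assumes that "folding in" a row means elementwise multiplication, and it omits the time-kernel modulation. Previously chosen indexes are zeroed out instead of merely downweighted.

```python
import numpy as np


def affinity_sample(affinity, main_idx, num_frames, rng=None, deterministic=False):
    """
    Hypothetical sketch of the affinity-based query (not the geowatch code).

    affinity: (N, N) nonnegative array; affinity[i, j] is how much frame i
        wants to be sampled together with frame j.
    main_idx: the time index of interest (always included in the result).
    num_frames: total number of frames to return.
    """
    rng = np.random.default_rng() if rng is None else rng
    chosen = [main_idx]
    # The row of the frame of interest is the initial distribution
    probs = affinity[main_idx].astype(float).copy()
    while len(chosen) < num_frames:
        # Assumption: chosen frames are removed entirely, not just downweighted
        probs[chosen] = 0
        total = probs.sum()
        if total <= 0:
            raise ValueError("no more frames to sample from")
        if deterministic:
            next_idx = int(np.argmax(probs))
        else:
            next_idx = int(rng.choice(len(probs), p=probs / total))
        chosen.append(next_idx)
        # Assumption: "folding in" the new row is an elementwise product
        probs = probs * affinity[next_idx]
    return sorted(chosen)
```

Each step touches a full row of length N, and the whole query is repeated for every grid location. This explains why the cost adds up during grid building.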
The problem is that this query is slow. It bottlenecks startup time in the initial grid-building stage when training on new problems or on new machines. This is somewhat mitigated by the fact that the result of the initial grid building is cached, but it is still annoying for users, and it doesn't have to be this way.
Recall that the time kernel is a way to specify a distribution as a set of "ideal sample times" (illustrated by the green dashed lines). This kernel determines a distribution (the multicolored curves) that provides a randomized (useful at train time) or deterministic sampling over a dense (or sparse) set of time samples.
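The exact curve shape geowatch uses is not described here. The sketch below assumes each ideal sample time produces an unnormalized Gaussian bump with a shared width `sigma`, and that kernel offsets are given in days relative to the main time. `time_kernel_curves` is a hypothetical helper, not an existing geowatch function.

```python
import numpy as np


def time_kernel_curves(time_kernel, main_time, times, sigma):
    """
    Evaluate one curve per "ideal sample time" in a time kernel.

    Assumptions (hypothetical, not taken from geowatch): each curve is an
    unnormalized Gaussian centered at main_time + offset, all curves share
    the width sigma, and all times are in the same unit (e.g. days).

    Returns a (len(time_kernel), len(times)) array; row k is curve k.
    """
    offsets = np.asarray(time_kernel, dtype=float)
    times = np.asarray(times, dtype=float)
    centers = main_time + offsets              # the ideal sample times
    diffs = times[None, :] - centers[:, None]  # (K, T) distance to each ideal time
    return np.exp(-0.5 * (diffs / sigma) ** 2)
```

Evaluating this on a dense grid, e.g. `time_kernel_curves([-365, -30, 0, 30, 365], 0.0, np.arange(-400, 401), sigma=15.0)`, gives one peaked curve per kernel entry, analogous to the multicolored curves in the figure.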
When we are given a set of "discrete observations" (the thick dark-blue lines), each curve in the time kernel assigns a probability to each observation. We can then sample proportionally (randomized) or take the maximum-probability observation (deterministic). We also have to handle the cases where there are not enough observations or where the same time is chosen twice.
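Below is one possible shape for a kernel-only sampler along these lines. It is a sketch of the idea, not an existing implementation. `kernel_sample` and `sigma` are hypothetical. The Gaussian curve shape is assumed, as is the rule that the curve nearest offset 0 is reserved for the main frame. Each curve greedily claims one still-unused observation. A curve with no usable observation is skipped, so fewer frames are returned when data is scarce.

```python
import numpy as np


def kernel_sample(obs_times, main_idx, time_kernel, sigma, rng=None, deterministic=False):
    """
    Hypothetical kernel-only sampler (a sketch of the proposed idea).

    For each ideal time in the kernel, weight every observation by a Gaussian
    of its distance to that ideal time, then pick one observation per curve:
    proportionally (randomized) or by argmax (deterministic).
    """
    rng = np.random.default_rng() if rng is None else rng
    obs_times = np.asarray(obs_times, dtype=float)
    offsets = np.asarray(time_kernel, dtype=float)
    centers = obs_times[main_idx] + offsets
    # (K, N) weight of each observation under each kernel curve
    weights = np.exp(-0.5 * ((obs_times[None, :] - centers[:, None]) / sigma) ** 2)

    chosen = {main_idx}
    available = np.ones(len(obs_times), dtype=bool)
    available[main_idx] = False
    # Assumption: the curve closest to offset 0 is filled by the main frame
    main_curve = int(np.argmin(np.abs(offsets)))
    for k in range(len(offsets)):
        if k == main_curve:
            continue
        w = weights[k] * available  # mask out used frames: no time chosen twice
        total = w.sum()
        if total <= 0:
            continue  # not enough observations left for this curve
        if deterministic:
            idx = int(np.argmax(w))
        else:
            idx = int(rng.choice(len(w), p=w / total))
        chosen.add(idx)
        available[idx] = False
    return sorted(chosen)
```

The cost is one (K, N) weight matrix per query, with K the kernel size. That is typically much smaller than the N x N affinity matrix. The greedy per-curve order is a simplification. Assigning observations to curves jointly, for example with `scipy.optimize.linear_sum_assignment` on negated weights, is one deterministic alternative.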
I'm thinking that a simpler method based only on the time kernel would do about as well while being significantly faster.
Looking for ideas / implementations.