Track more metrics at train time
I would like to extend our TensorBoard logging to include more statistics about training dynamics. This will help us build better intuition when tuning parameters like the learning rate and batch size.
Specifically, I want to be able to track everything recorded in the "Loss of Plasticity in Deep Continual Learning" paper, as described in Section 4.
That includes:
Weight Magnitude:
We measure the average magnitude of the weights by adding up their absolute value and dividing by the total number of weights in the network.
This can likely be computed on each step, although we should have a flag to enable/disable it, or compute it less frequently, since it will likely be expensive. The alternative is to do this at the start or end of each training epoch. Does Lightning already have something built in for this?
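As a sketch, the weight-magnitude statistic could look something like the following (the function name `average_weight_magnitude` is my own; only standard `torch` APIs are assumed):

```python
import torch
from torch import nn


def average_weight_magnitude(model: nn.Module) -> float:
    """Sum of absolute parameter values divided by the total parameter count."""
    total_abs = 0.0
    total_count = 0
    with torch.no_grad():
        for param in model.parameters():
            total_abs += param.abs().sum().item()
            total_count += param.numel()
    return total_abs / max(total_count, 1)
```

A callback could call this once per epoch (or every N steps behind a flag) and log the scalar.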
Dead Units:
To measure the number of dead units in a network with ReLU activation, we count the number of units with a value of zero for all examples in a random sample of two thousand images at the beginning of each new task. An analogous measure in the case of sigmoid or tanh activations is the number of units that are ϵ away from either of the extreme values of the function for some small positive ϵ [42]. We only focus on ReLU networks in this section.
But in an online setting it might make more sense to compute this number on the validation set. This measure is defined with respect to a set of activations, so I'm not sure it makes computational sense to run at train time unless we compute it per-batch.
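Given a matrix of post-ReLU activations for one layer (which we would need to collect, e.g. with forward hooks, over a sample of examples), the count itself is simple. A minimal sketch, with `count_dead_relu_units` being a hypothetical name:

```python
import torch


def count_dead_relu_units(activations: torch.Tensor) -> int:
    """Count units that output exactly zero for every sampled example.

    `activations` is an (N, num_units) matrix of post-ReLU outputs for one
    hidden layer, e.g. collected via forward hooks over ~2000 samples.
    """
    dead_mask = (activations == 0).all(dim=0)
    return int(dead_mask.sum().item())
```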
Effective Rank:
The effective rank takes into consideration how each dimension influences the transformation induced by a matrix (Effective Rank Paper). A high effective rank signals that most of the dimensions of the matrix contribute equally to the transformation induced by the matrix. On the other hand, a low effective rank corresponds to few dimensions having any significant effect on the transformation, implying that the information in most of the dimensions is close to being redundant.
Formally (using zero-based indexes, which differs from the paper):
Let Φ ∈ ℝ^(n×m) be a matrix
Let σ be the vector of singular values of Φ, of length q, where q = min(n, m)
# Normalized singular values
Let p[k] = σ[k] / ℓ1_norm(σ)
# Entropy of normalized singular values (with the convention 0 * log(0) = 0)
Let H(p) = -sum(p[k] * log(p[k]) for k in range(q))
Define: erank(Φ) = exp(H(p))
The effective rank of a hidden layer measures the number of units that can produce the output of the layer. We approximate the effective rank on a random sample of two thousand examples before training on each task.
I believe they compute the effective rank for the weights of each layer in the network separately and then average the results. This seems like something we could do either each step or each epoch, and it probably has a computational cost similar to tracking weight magnitude.
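The pseudocode above translates fairly directly to PyTorch. A sketch (assuming only `torch.linalg`, and using the 0·log(0) = 0 convention for zero singular values):

```python
import torch


def effective_rank(phi: torch.Tensor, eps: float = 1e-12) -> float:
    """erank(Phi) = exp(entropy of the L1-normalized singular values)."""
    sigma = torch.linalg.svdvals(phi)
    p = sigma / sigma.sum().clamp_min(eps)
    p = p[p > 0]  # drop zeros: convention 0 * log(0) = 0
    entropy = -(p * p.log()).sum()
    return float(entropy.exp().item())
```

Sanity check: for the identity matrix all singular values are equal, so the effective rank equals the full rank, while a rank-one matrix gives an effective rank of 1.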
Reference implementations can likely be found in the LOP repo: https://github.com/shibhansh/loss-of-plasticity
Ideally these would be implemented as pytorch-lightning callbacks to make them generally useful. We keep our current set of callbacks in geowatch/utils/lightning_ext/callbacks.
It would also be good to include gradient magnitude.
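Gradient magnitude could mirror the weight-magnitude statistic. A sketch (hypothetical function name; it assumes gradients are already populated, e.g. called right after the backward pass from a Lightning `on_after_backward` hook):

```python
import torch
from torch import nn


def average_gradient_magnitude(model: nn.Module) -> float:
    """Mean absolute gradient over all parameters with a populated .grad."""
    total_abs = 0.0
    total_count = 0
    for param in model.parameters():
        if param.grad is not None:
            total_abs += param.grad.abs().sum().item()
            total_count += param.grad.numel()
    return total_abs / max(total_count, 1)
```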