Unpacking K-P Flow:
The Geometry of GD Learning in General Recurrent Models


Relevant Links
Arxiv Pre-Print · PyTorch Package

TLDR
Q: Can we track the evolution of a dynamical system trained by Gradient Descent (GD)?
A: Yes, but the NTK involves tensor calculus. We need new geometric intuition and dynamical tools.



Synopsis
This blog post provides a more intuitive exploration of some parts of our main paper, linked above, which broadly explores the gradient flow of general recurrent models. An efficient, in-development PyTorch package with examples is also linked above. See the accompanying blog posts on my main page exploring specific aspects of the code.





Broader Motivation and Prior Work


Recurrent Models in Neuroscience and ML Recurrent models in machine learning are a powerful tool that can mimic sequential behavior when fit to data. Practically, this can be used to solve a variety of tasks in deep learning and control. Furthermore, such models can be used as a proxy for understanding how circuitry in the brain forms so-called neural manifolds, consisting of low-dimensional dynamical motifs such as fixed points and attractors. In multi-task learning contexts, recurrent models give insights into solving complex continual learning or compositional problems.


Many, Many Recurrent Architectures Classically, in control, the most common recurrent model is the Linear Time Invariant (LTI) controller, a linear dynamical system that can be trained on a variety of problems and is well analyzed theoretically. In deep learning, the most classical model is the Recurrent Neural Network (RNN), with non-linear time-stepping dynamics. Building on this model are GRUs and LSTMs, which train better by avoiding the pitfalls of vanishing and exploding gradients. More recently, many diverse recurrent architectures have appeared in deep learning, including those with dynamical synapses (MPNs) or State-Space Models (SSMs). Outside of these contexts, we can look to physics and neuroscience for a very wide range of complex network-based recurrent dynamical systems, such as spin glasses, Hodgkin-Huxley biophysical neural networks, Hopfield neural networks, energy-based models, and many, many more.


Training Recurrent Models In control and deep learning, the primary means of "training" such recurrent models is dynamically adapting their parameters (typically weights of the recurrent or output connections) with (potentially stochastic or accelerated) Gradient Descent (GD). More specifically, on discrete models like RNNs or GRUs, we use Backpropagation Through Time (BPTT) to efficiently compute gradients, which are then used to incrementally adjust the parameters. In optimal control or continuous Neural ODEs, this may instead be labeled the adjoint method, but it was proven long ago that BPTT and the adjoint method are exactly the same thing.


Tracking Trained Dynamics In this study, we investigated how the dynamics of a recurrent model trained with GD evolve. More specifically, we define operators that, when composed together, define the hidden-state gradient flow of a general parameterized recurrent dynamical system.

Introduction

Gradient Descent (GD) Setup

A General Recurrent Model As described, we consider a recurrent dynamical system, trained with GD to approximately mirror target trajectories given a variety of time-varying inputs. We let \(h(t)\) denote the model's hidden state, and \(x(t)\) denote a sample task input. The variable \(t\) denotes the forward-pass time and lives in a given time range \([0, t_{end}]\). During forward-pass inference, an individual task input \(x(t)\) is chosen from a given distribution of all task inputs; then, the model state \(h(t)\) is simulated from an initial state \(h_0\) to time \(t_{end}\), driven by the input \(x(t)\). The dynamics can be specified continuously or discretely with little change in the mathematics below, so we assume time is continuous. In this case, the model dynamics take the form: $$\frac{d}{dt} h(t) = f(h(t), x(t), \theta), \text{ for } t \in (0, t_{end}]; \text { with } h(0) = h_0.$$ Here, \(\theta\) denotes the model parameters (e.g., weights and biases) and \(f\) models the tangential dynamics of the hidden state.
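To make this concrete, here is a minimal sketch of such a model in PyTorch: a leaky-tanh RNN whose continuous dynamics are integrated with a forward-Euler step. The class and all names (EulerRNN, W_rec, W_in, dt) are illustrative assumptions on my part, not taken from the paper's package.

```python
import torch

# A minimal sketch of the general recurrent model above: a leaky-tanh RNN
# whose continuous dynamics dh/dt = f(h, x, theta) are integrated with a
# forward-Euler step. All names are illustrative.
class EulerRNN(torch.nn.Module):
    def __init__(self, hidden_size, input_size, dt=0.1):
        super().__init__()
        self.W_rec = torch.nn.Linear(hidden_size, hidden_size)
        self.W_in = torch.nn.Linear(input_size, hidden_size, bias=False)
        self.dt = dt

    def f(self, h, x):
        # Tangential dynamics f(h(t), x(t), theta).
        return -h + torch.tanh(self.W_rec(h) + self.W_in(x))

    def forward(self, x_seq, h0):
        # x_seq: [B, T, input_size]; h0: [B, H]. Returns hidden states [B, T, H].
        h, states = h0, []
        for t in range(x_seq.shape[1]):
            h = h + self.dt * self.f(h, x_seq[:, t])  # Euler step of dh/dt = f
            states.append(h)
        return torch.stack(states, dim=1)
```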

Defining the Task For each input \(x(t)\) we associate a desired output target \(y^*(t)\). Furthermore, we define an output function of the hidden state, e.g., the simple affine map: $$y(t) = W_{out} h(t) + b_{out}.$$ Then, the goal is to minimize a loss function \(\ell(y(t), y^*(t))\) over all possible inputs and at all times. In particular, we define the loss as an average (denoted by \(\langle \cdot \rangle_{x,t}\)): $$L := \langle \ell(y(t), y^*(t)) \rangle_{x,t}.$$

Training the Model with GD The goal is to minimize the loss \(L\) by tuning the parameters of the model, \(\theta\). To do so, we want to find the best parameters: $$\theta^* := \arg \min_\theta L.$$ This is typically done by GD, where the parameters are iteratively updated by travelling down the steepest direction in Euclidean parameter space. Specifically, letting \(\delta \theta\) be the perturbation to the parameters at a particular instant in GD, it is given by: $$\delta \theta := -\eta \nabla_\theta L.$$ Here, \(\eta\) is a learning rate, which I'll just assume is \(1\) throughout the rest of this blog for simplicity.
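Continuing the sketch above, a plain GD step on this setup might look as follows: an affine readout, the loss averaged over batch inputs and times, and \(\eta = 1\) as assumed in the text. The data tensors and dimensions are illustrative placeholders.

```python
# A minimal sketch of the GD setup, continuing the EulerRNN example above:
# affine readout y = W_out h + b_out, loss averaged over batch and time,
# and plain gradient descent with eta = 1. All names are illustrative.
B, T, H, D_in = 4, 25, 16, 2
model = EulerRNN(hidden_size=H, input_size=D_in)
readout = torch.nn.Linear(H, 1)                # y = W_out h + b_out
opt = torch.optim.SGD(list(model.parameters()) + list(readout.parameters()), lr=1.0)

x_seq = torch.randn(B, T, D_in)                # sampled task inputs x(t)
y_star = torch.randn(B, T, 1)                  # target outputs y*(t)
h0 = torch.zeros(B, H)

def loss_fn():
    h = model(x_seq, h0)                       # [B, T, H]
    y = readout(h)                             # [B, T, 1]
    return ((y - y_star) ** 2).mean()          # L = <l(y, y*)>_{x,t}

# One GD step; the gradients come from BPTT through the unrolled dynamics.
opt.zero_grad()
loss_fn().backward()
opt.step()
```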

The Hidden State GD Flow and NTK

Tracking the GD Flow We would like to track certain quantities as they evolve under GD. The simplest quantity to track is the parameters, defining the so-called parameter GD flow. However, this quantity is ultimately a proxy for the actual dynamics of the hidden state, which also evolve as GD trains the model. We would like to track how this quantity itself evolves in a meaningful way. The hidden-state dynamics can be envisioned as a large vector field, conditioned on the particular input you give the model.

Now, before getting into the details of the objects involved, the general idea (which turns out to work correctly) is simply as follows. Firstly, define the Instantaneous Error Signal, Err, by $$\text{Err} := \nabla_{h} \ell(y, y^*).$$ For example, when we use the squared error loss \(\ell(y, y^*) := \frac{1}{2}\| y - y^*\|^2\), the error signal is \(\text{Err} = W_{out}^T (y - y^*)\), a simple residual projected onto the row space of \(W_{out}\). Then, by the chain rule, $$\nabla_\theta L = \left(\frac{d h}{d \theta}\right)^T \text{Err}.$$ Given this parameter change, the hidden state itself will approximately change by $$\delta h = -\frac{d h}{d \theta} \cdot \nabla_\theta L = -\frac{d h}{d \theta} \cdot \frac{d h}{d \theta}^T \text{Err}.$$ The Jacobian outer product \(\frac{d h}{d \theta} \cdot \frac{d h}{d \theta}^T\) is referred to as the Neural Tangent Kernel (NTK).
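At toy scale, the Jacobian outer product can be materialized directly with torch.func, reusing the model and data from the sketches above. This is a hedged illustration of the definition, only feasible for very small models, as the next paragraphs explain.

```python
from torch.func import functional_call, jacrev

# Toy-scale sketch: materialize dh/dtheta with torch.func, then form the NTK
# as the Jacobian outer product. Reuses model, x_seq, h0 from above.
params = dict(model.named_parameters())

def hidden_states(p):
    # The forward pass as a pure function of the parameters; returns [B, T, H].
    return functional_call(model, p, (x_seq, h0))

jac = jacrev(hidden_states)(params)            # per-parameter Jacobians
# Flatten each parameter's axes and concatenate: J has shape [B, T, H, M].
J = torch.cat([j.reshape(*j.shape[:3], -1) for j in jac.values()], dim=-1)
M = J.shape[-1]
Jmat = J.reshape(B * T * H, M)
Phi = Jmat @ Jmat.T                            # discretized NTK, [B*T*H, B*T*H]
```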

The Classical NTK In the classical NTK literature, the model considered is a multi-layer perceptron neural network with scalar output, so there is no notion of time, \(t\), and no spatial (hidden unit) dimension. Hence, the NTK is just a B by B matrix over all batch input trials. The NTK entries pinpoint how much GD corrections correlate across batch trials: do the parameter corrections proposed by two distinct task inputs agree, not overlap at all, or disagree, interfering with each other when we do the actual parameter update?

The NTK is a tensor operator In our case, however, the NTK is actually a tensor operator. In a PyTorch implementation, for example, the hidden state will have the form [B, T, H] over batch inputs, \(x\), times, \(t\), and hidden units. The quantity \(\frac{d h}{d \theta}\) thus has the form [B, T, H, M], where \(M\) is the size of the flattened model parameters. Finally, \(\Phi\) is a linear operator on the space of tensors of the form [B, T, H]. If we discretize it, it will be a massive matrix of shape B*T*H by B*T*H. If indexed as in PyTorch, the entries of \(\Phi\) are given by $$\Phi: \mathbb{R}^{B\times T \times H} \rightarrow \mathbb{R}^{B \times T \times H}$$ $$\Phi[b, t, h, b_0, t_0, h_0] = \sum_{m=1}^M \frac{d h}{d \theta}[b, t, h, m] \frac{d h}{d \theta}[b_0, t_0, h_0, m].$$ Formally, this is a tensor product on \(\frac{d h}{d \theta}\) where the parameter dimension is contracted (see my blog post going into more depth on tensor calculus). The fact that the NTK is a tensor operator for such general recurrent models has pros and cons. Pros include that it encapsulates a massive amount of information, including insights into GD learning at very granular levels given by eigenfunctions. Indeed, in the next section I'll discuss some of the ways we can reduce the operator to generate different perspectives on GD learning. However, this is also a con, since it makes the whole object harder to understand, requiring some tensor math. Another major con is the massive size of this object. Even for very small problems it can be huge (e.g., 100 hidden units, batch size 100, and 100 forward-pass times result in \(\Phi\) being a million by a million when discretized). Thus, advanced matrix-free methods that do not actually compute the full discretized operator are needed to work with \(\Phi\) (see my blog using trace estimation for the NTK, for example).
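As one concrete illustration of such a matrix-free approach, \(\Phi v\) can be computed with one vector-Jacobian product (contracting the parameter axis) followed by one Jacobian-vector product (expanding it back), never forming \(\Phi\). This sketch reuses hidden_states and params from above; it is not the package's implementation.

```python
from torch.func import jvp, vjp

# Matrix-free sketch: apply Phi to v of shape [B, T, H] via
# g = (dh/dtheta)^T v (a VJP), then Phi v = (dh/dtheta) g (a JVP).
def ntk_apply(v):
    _, vjp_fn = vjp(hidden_states, params)
    (g,) = vjp_fn(v)                     # dict of param-shaped tensors
    _, Phi_v = jvp(hidden_states, (params,), (g,))
    return Phi_v                         # [B, T, H]

v = torch.randn(B, T, H)
Phi_v = ntk_apply(v)
```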

Working with Tensor Operators

As mentioned in the prior section, the full empirical NTK for recurrent models is a linear operator on a space of 3-tensors (discretely, of shape [B, T, H]). In this section, I'll discuss building some more intuition and practical tools for working with such operators. Indeed, it may be tempting to think that working with such a massive, complex object is overcomplicating things. However, this object exactly matches what PyTorch computes, without any simplification, and can be simplified after the fact instead of at the outset. As we will see later, keeping the full operator in all its complexity allows us to further decompose it into individual components, \(\mathcal{P}\) and \(\mathcal{K}\), operators with their own distinct structure that can be exploited.


Fig 1 Schematic of the hierarchy of views attainable by different reductions of the full hidden-state NTK operator (or related operators).

Reduced Views A simple way to work with this complex operator is to reduce it into simpler "views". The operator itself informs how gradient updates to the hidden state are structured over time, batch inputs, and hidden units, which is a ton of information. Sometimes, we would like to know in what subspaces the updates will reside, on average, without considering time or batches. Or, as another example, we may want to consider where most of the updates will be concentrated as a signal over time and batches, without thinking about the individual hidden units.

To generate such views, we define a method of reducing the operator over particular axes. Given an axis (time, batches, or hidden units), the view essentially averages the operator over that axis. For example, let's define the time-averaged NTK \(\langle \Phi \rangle_t\). This is a linear operator now on 2-tensors of shape [B, H]. Specifically, $$\langle \Phi \rangle_t[b, h, b_0, h_0] = \langle \Phi[b, t, h, b_0, t_0, h_0] \rangle_{t, t_0},$$ i.e., we average over the time axes. This operator takes in a tensor of shape [B, H], produces a tensor of shape [B, T, H] by making T copies, applies \(\Phi\) to this tensor, then averages the final result over the time axis. In my code, you can find the implementation in op_common.py, called AveragedOperator. Similarly, we can define averaging operations over any of the axes: time, hidden units, or batches, or multiple simultaneously.
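Here is a minimal sketch of such a time-averaged view built on the matrix-free ntk_apply above. It mirrors the idea behind AveragedOperator, but is an illustrative stand-in rather than the actual op_common.py code.

```python
# Sketch of the time-averaged view <Phi>_t, acting on 2-tensors of shape [B, H]:
# make T copies along the time axis, apply Phi, then average the result back.
def ntk_apply_time_avg(q):
    q_full = q.unsqueeze(1).repeat(1, T, 1)    # [B, H] -> [B, T, H]
    return ntk_apply(q_full).mean(dim=1)       # back to [B, H]
```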


Some Example Use Cases

As a Matrix Over Hidden Units Using these views, we can measure interpretable properties of the NTK and gradient flow. For example, suppose we want to quantify where the gradient updates will be concentrated in the hidden space, regardless of time or batch trials. Then, we can measure the averaged operator \(\langle \Phi \rangle_{x,t}\), which is now an H by H matrix. We can then measure the singular vectors \(\{v_i\}_{i=1...n}\) of this matrix, and its singular values \(\{\sigma_i\}_{i=1...n}\). The principal vector \(v_1\) explains exactly which types of hidden-space inputs will maximally stimulate the operator \(\Phi\), on average over input batches and times. Note that two trials with different inputs, \(x_1, x_2\), to the hidden state may produce trajectories in completely disparate regions of the hidden space. However, by design, both of these directions will factor distinctly into the averaged SVD (see the Appendix of my paper for more explanation of this). If we want to see where gradient updates will be correlated, we can use the left singular vectors, \(\{u_i\}_{i=1...n}\): the first vector \(u_1\) explains where \(\Phi\)'s updates will be concentrated if we feed it many random input signals, i.e., where GD is likely to concentrate. Finally, the effective rank of the operator \(\langle \Phi \rangle_{x,t}\) explains how hidden-space GD updates are constrained: if it is low, it means the operator will be constrained to producing updates in a small, low-dimensional hidden subspace.
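As a sketch of this use case (under the same toy setup as above), we can materialize \(\langle \Phi \rangle_{x,t}\) column by column, take its SVD, and compute one common notion of effective rank, the participation ratio; other definitions exist, and the paper's own choice may differ.

```python
# Sketch: build <Phi>_{x,t} as an H x H matrix by applying the fully averaged
# operator to hidden-space basis vectors, then inspect SVD and effective rank.
def ntk_apply_full_avg(q):
    # q: [H] -> broadcast over batches and times, apply Phi, average back down.
    q_full = q.expand(B, T, H).contiguous()
    return ntk_apply(q_full).mean(dim=(0, 1))  # [H]

Phi_avg = torch.stack([ntk_apply_full_avg(e) for e in torch.eye(H)], dim=1)
U, S, Vt = torch.linalg.svd(Phi_avg)           # u_i, sigma_i, v_i
eff_rank = S.sum() ** 2 / (S ** 2).sum()       # participation ratio
```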

Eigenfunction Signals Averaging Hidden Space As another example, we can consider the operator \(\langle \Phi \rangle_{h}\), averaged over spatial hidden units. If we take the SVD of this operator, we get input-dependent eigenfunctions \(\{\phi_i\}_{i=1...BT}\), each a tensor of shape [B, T]. Each eigenfunction explains which points in time are most crucial to stimulating \(\Phi\), for each individual one of the batch trials. On trial 1, it might say that a particular time window of the task is very crucial, i.e., most of the learning will be concentrated there, while on trial 2, with a different batch input for the model, there could be a completely distinct time window. Consequently, these eigenfunction signals explain the temporal structure imposed by the task on learning, across each input-driven trial of the model inference. Since the operator is averaged over hidden units, we don't take into account which hidden units to stimulate; instead, we care about which times and independent input trials are most relevant to the GD update.

Below is a figure explained more fully in my paper, illustrating the eigenfunctions of the operator \(\langle \mathcal{K} \rangle_{h}\), averaged over hidden units, as in the previous paragraph. A full description of \(\mathcal{K}\) itself is given below. The color of each individual trajectory indicates a distinct batch trial. Note that each eigenfunction is a tensor of shape [B, T], as above, informing, on each batch trial, which times are most significant for stimulating the operator. For the task setup and more comprehensive details, see the paper itself.

Fig 2 Eigenfunctions and effective rank of the parameter operator \(\langle \mathcal{K} \rangle_{h}\), reduced over the hidden dimension.

The Full Operator SVD

Similar to the previous use case, we can take the SVD of the full operator, \(\Phi\), without averaging any axes. Then, the eigenfunctions and singular values you get inform how to stimulate \(\Phi\) at a very granular level: at each forward-pass timestep, each hidden unit, and each distinct input-driven trial.

Decomposing the NTK into \(\mathcal{P}\) and \(\mathcal{K}\)

Fig 3 Schematic of the K-P flow decomposition of the full NTK associated with a recurrent dynamical system, as detailed in this blog post.

To summarize the work so far, I have discussed the NTK, explained why it is a linear operator on 3-tensors for recurrent dynamical models, and discussed how to work with such a weird object. In this section, I will introduce the decomposition in my main paper, breaking the recurrent NTK for these general models into a product: $$\Phi = \mathcal{P} \mathcal{K} \mathcal{P}^*$$ where \(\mathcal{P}\) and \(\mathcal{K}\) are linear operators themselves, which I'll now define, and \(\mathcal{P}^*\) denotes the Hermitian adjoint of the operator \(\mathcal{P}\) (typically just the transpose of the discretized matrix).

Derivation Let's take a step back and consider how backpropagation forms gradients in recurrent models. This process is known as Backpropagation-Through-Time (BPTT) in discrete contexts and is provably identical to the Adjoint Method in continuous contexts.

Computing the Adjoint (the \(\mathcal{P}^*\) operator): With the error signal, \(\text{Err} := \nabla_h \ell(y, y^*)\), as above, we backpropagate to accumulate the error signal in reverse time, computing the so-called adjoint defined by $$a := \nabla_h L = \nabla_h \langle \ell \rangle_t.$$ In particular, the backpropagation stage computes the adjoints backwards through time as: $$a(t) = J_h(t)^T a(t+1) + \text{Err}(t),$$ where \(J_h(t) = d h(t+1) / d h(t) = d f / d h(t)\) is the state Jacobian of the model dynamics. In my formulation, this is exactly what the operator \(\mathcal{P}^*\) does: it describes the transformation from the error signal, \(\text{Err}\), to the adjoint, \(a\), given exactly by backpropagating, as above. Essentially, the operator \(\mathcal{P}\) accumulates perturbations to the hidden state forwards through time using \(J_h\), while its adjoint, \(\mathcal{P}^*\), backpropagates using \(J_h^T\), as above. Thus, we can think of \(\mathcal{P}\) as a Forward Propagation Operator while its adjoint \(\mathcal{P}^*\) is a Backward Propagation Operator. This is described in more detail below.
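In discrete time, this backward accumulation is a short reverse loop. Below is a minimal single-trial sketch with illustrative shapes: Err as a [T, H] tensor of error signals and Jh as a list of the [H, H] state Jacobians.

```python
# Sketch of the backward recursion computed by P* on a single trial:
# a(t) = J_h(t)^T a(t+1) + Err(t), swept backwards from the final time.
def backward_propagate(Err, Jh):
    T_steps = Err.shape[0]
    a = torch.zeros_like(Err)
    a[T_steps - 1] = Err[T_steps - 1]
    for t in range(T_steps - 2, -1, -1):
        a[t] = Jh[t].T @ a[t + 1] + Err[t]
    return a
```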

Computing \(\delta \theta\) (the \(\mathcal{K}\) operator): Once the adjoint is computed, the parameter perturbation GD will propose is given by $$\delta \theta = -\langle J_\theta^T \cdot a \rangle_{x, t},$$ where \(J_\theta = d f / d \theta\) is the one-step parameter Jacobian. This defines the \(\mathcal{K}\) operator. In particular, \(\mathcal{K}\) maps a particular tensor \(q\) of shape [B, T, H] to a new tensor of the same shape by $$(\mathcal{K} q)[b, t, i] = J_\theta[b, t, i] \cdot \langle J_\theta^T q \rangle_{x,t}.$$
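A minimal sketch of this action follows, assuming the one-step parameter Jacobian \(J_\theta\) has been materialized as a tensor Jtheta of shape [B, T, H, M]; that materialization is an illustrative assumption, and in practice \(\mathcal{K}\) would be applied matrix-free.

```python
# Sketch of K acting on q of shape [B, T, H]: first contract to the averaged
# parameter perturbation <J_theta^T q>_{x,t}, then expand it back out.
def K_apply(q, Jtheta):
    B_, T_ = q.shape[0], q.shape[1]
    g = torch.einsum('bthm,bth->m', Jtheta, q) / (B_ * T_)  # <J_theta^T q>_{x,t}
    return torch.einsum('bthm,m->bth', Jtheta, g)           # (K q)[b, t, i]
```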

In total, \(\mathcal{K} a = \mathcal{K} \mathcal{P}^* \text{Err}\) defines the one-step perturbations \(\delta f\) to the model dynamics proposed by GD, before integrating these up over time. Integrating them over time transforms them into the actual changes \(\delta h\) to the hidden state, which is exactly what the forward operator \(\mathcal{P}\) does. In particular, the full GD flow is $$\delta h = - \mathcal{P K P^*} (\text{Err}),$$ decomposing the NTK as \(\Phi = \mathcal{P K P^*}\).

Why does this organization matter? It cleanly separates the two distinct roles in the gradient flow: \(\mathcal{K}\) captures the instantaneous, parameter-mediated correlations between error signals across trials and times, while \(\mathcal{P}\) and \(\mathcal{P}^*\) capture how the model's own dynamics propagate those one-step corrections forwards and backwards through time. Each factor has its own structure that can be probed separately, for example by stimulating \(\mathcal{P}\) and \(\mathcal{K}\) individually with random input signals and the reduced views discussed above.
