I'm excited about this work; what seemed like a simple problem -- effective forward-mode differentiation -- turned out to be a rich source of deeply interesting connections and possibilities.
Differentiation through time
We'll consider a general class of differentiable recurrent models with state h_t governed by

h_t = F (h_{t - 1}, x_t ; \theta_t) .

Here x_t is an observation made at time t, and \theta_t are the RNN's parameters (e.g. a weight matrix). As usual, the parameters are shared over time, i.e. \theta_t \equiv \theta; having the subscript t on the parameters \theta conveniently lets us refer to the partial derivative \partial h_t / \partial \theta by the total derivative \mathrm{d} h_t / \mathrm{d} \theta_t. We will generally use the notation \mathcal{J}^y_x for the (total) Jacobian matrix of y with respect to x.
*[Figure: State propagation through an RNN]*
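To make this concrete, here is a minimal sketch (Python/NumPy, with made-up sizes) of the kind of model we will use as a running example: a vanilla tanh RNN whose parameters \theta are the entries of a single weight matrix W applied to the concatenation of the previous state, the current observation and a bias.

```python
import numpy as np

# A minimal instance of h_t = F(h_{t-1}, x_t; theta): a vanilla tanh RNN.
# theta = vec(W); a_t = [h_{t-1}; x_t; 1] is the concatenated input.
# Sizes are hypothetical, chosen only for illustration.
H, X = 4, 3          # hidden and observation dimensions
A = H + X + 1        # augmented input dimension

def step(h, x, W):
    a = np.concatenate([h, x, [1.0]])   # a_t
    return np.tanh(W @ a)               # h_t = f(W a_t)

rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(H, A))
h = np.zeros(H)
for t in range(5):
    h = step(h, rng.normal(size=X), W)
```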
At each step, the RNN incurs a loss L_t, a differentiable function of the hidden state h_t. In order to optimize \theta to minimize the total loss L = \sum_{t = 1}^T L_t over a sequence of length T, we require an estimate of the gradient \mathcal{J}^L_{\theta}.
\mathcal{J}^L_{\theta} = \sum_{t = 1}^T \sum_{s = 1}^T \mathcal{J}^{L_t}_{\theta_s} = \underbrace{\sum_{s = 1}^T \left( \sum_{t = s}^T \mathcal{J}^{L_t}_{h_s} \right) \mathcal{J}^{h_s}_{\theta_s}}_{\text{reverse accumulation}} = \underbrace{\sum_{t = 1}^T \mathcal{J}^{L_t}_{h_t} \left( \sum_{s = 1}^t \mathcal{J}^{h_t}_{\theta_s} \right)}_{\text{forward accumulation}}
Each of the terms \mathcal{J}^{L_t}_{\theta_s} indicates how the use of the parameter \theta at time s contributed to the loss at time t. The triangular/causal structure \mathcal{J}^{L_t}_{\theta_s} = \mathbb{1}_{s \leqslant t} \mathcal{J}^{L_t}_{\theta_s} allows two useful factorizations of the double sum.
The first, labeled reverse accumulation, is used by the popular Backpropagation Through Time algorithm (BPTT). In BPTT we run the model forward to compute the activations and losses, and subsequently run backward to propagate the gradient \mathcal{J}^L_{h_t} back through time:
\mathcal{J}^L_{h_t} = \mathcal{J}^L_{h_{t + 1}} \mathcal{J}^{h_{t + 1}}_{h_t} + \mathcal{J}^{L_t}_{h_t} .
By following this recursion, we can aggregate the terms \mathcal{J}^L_{\theta_t} = \mathcal{J}^L_{h_t} \mathcal{J}^{h_t}_{\theta_t} to compute the gradient. The backpropagation \mathcal{J}^L_{h_{t + 1}} \mathcal{J}^{h_{t + 1}}_{h_t} is a vector-matrix product, which has the same cost as the forward propagation of the state h_t in a typical RNN.
*[Figure: Back-propagation of \mathcal{J}^L_{h_t} by BPTT]*
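As a concrete illustration, here is a minimal BPTT sketch for the toy tanh RNN above. The per-step loss L_t = \frac{1}{2} \| h_t \|^2 is an arbitrary choice, made only so that the local gradient \mathcal{J}^{L_t}_{h_t} is simple; note the stack of stored activations.

```python
import numpy as np

# A minimal BPTT sketch for the toy tanh RNN, with the arbitrary per-step
# loss L_t = 0.5 * ||h_t||^2 (so that dL_t/dh_t = h_t).
H, X, T = 4, 3, 6
A = H + X + 1
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(H, A))
xs = rng.normal(size=(T, X))

# Forward pass: store activations on a stack for the backward pass.
hs, a_s = [np.zeros(H)], []
for t in range(T):
    a = np.concatenate([hs[-1], xs[t], [1.0]])
    a_s.append(a)
    hs.append(np.tanh(W @ a))

# Backward pass: propagate dL/dh_t back through time (vector-matrix products).
dW = np.zeros_like(W)
dh = np.zeros(H)                      # dL/dh_t, accumulated from the future
for t in reversed(range(T)):
    h = hs[t + 1]
    dh = dh + h                       # add the local term dL_t/dh_t
    dp = (1.0 - h**2) * dh            # back through the tanh nonlinearity
    dW += np.outer(dp, a_s[t])        # dL/dW_t = (dL/dp_t)^T a_t^T
    dh = W[:, :H].T @ dp              # dL/dh_{t-1}
```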
The second factorization (forward accumulation) is used by the Real-Time Recurrent Learning algorithm (RTRL). The recursion

\mathcal{J}^{h_t}_{\theta} = \mathcal{J}^{h_t}_{h_{t - 1}} \mathcal{J}^{h_{t - 1}}_{\theta} + \mathcal{J}^{h_t}_{\theta_t}
is chronological and can be computed alongside the RNN state. Given \mathcal{J}^{h_t}_{\theta}, we can compute the term \mathcal{J}^{L_t}_{\theta} = \mathcal{J}^{L_t}_{h_t} \mathcal{J}^{h_t}_{\theta} and immediately update the parameter \theta (with some technical caveats). The drawback is that the forward propagation \mathcal{J}^{h_t}_{h_{t - 1}} \mathcal{J}^{h_{t - 1}}_{\theta} is an expensive matrix-matrix product. Whereas BPTT cheaply propagates a vector \mathcal{J}^L_{h_t} of the same size as the RNN state, RTRL propagates a matrix \mathcal{J}^{h_t}_{\theta} that consists of one parameter-sized vector for each hidden unit. Since the number of parameters is typically quadratic in the number of hidden units, storing this matrix is cubic, and the forward propagation is quartic (i.e. for a meager 100 hidden units, RTRL is 10,000 times more expensive than BPTT!).
*[Figure: Forward-propagation of \mathcal{J}^{h_t}_{\theta} by RTRL]*
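For contrast, here is a minimal RTRL sketch on the same toy model. The matrix-matrix product in the Jacobian update is the expensive step, but gradients become available online and nothing needs to be stored across time.

```python
import numpy as np

# A minimal RTRL sketch for the toy tanh RNN. We maintain the full Jacobian
# J = dh_t/dtheta of shape (H, H*A) alongside the state; note the
# (H x H) @ (H x H*A) matrix-matrix product at every step.
H, X, T = 4, 3, 6
A = H + X + 1
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(H, A))
xs = rng.normal(size=(T, X))

h = np.zeros(H)
J = np.zeros((H, H * A))              # dh_t/dtheta, theta = vec(W), row-major
dW = np.zeros_like(W)
for t in range(T):
    a = np.concatenate([h, xs[t], [1.0]])
    h = np.tanh(W @ a)
    D = 1.0 - h**2                    # tanh derivative at the new state
    # Immediate term dh_t/dtheta_t = diag(D) (I kron a^T):
    imm = np.zeros((H, H * A))
    for i in range(H):
        imm[i, i * A:(i + 1) * A] = D[i] * a
    # RTRL recursion: carry the old Jacobian through the transition.
    J = (D[:, None] * W[:, :H]) @ J + imm
    # Online gradient for L_t = 0.5 * ||h_t||^2, available immediately:
    dW += (h @ J).reshape(H, A)
```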
It is important to note that the Jacobians involved in these recursions depend on the activations h_t and other intermediate quantities. In RTRL, these quantities naturally become available in the order in which they are needed, after which they may be forgotten. BPTT revisits these quantities in reverse order, which requires storing them on a stack.
This is the main drawback of BPTT: its storage grows with the sequence length T, which limits the temporal span of dependencies it can capture (as in truncated BPTT) and the rate at which parameter updates can occur.
RTRL has some things to recommend it, if only we had a way of dealing with that giant matrix \mathcal{J}^{h_t}_{\theta}.
Unbiased Online Recurrent Optimization
UORO (Tallec & Ollivier, 2017) approximates RTRL by maintaining a rank-one estimate of \mathcal{J}^{h_t}_{\theta} using random projections. A straightforward derivation starts from the expression

\mathcal{J}^{h_t}_{\theta} = \sum_{s \leqslant t} \mathcal{J}^{h_t}_{h_s} \mathcal{J}^{h_s}_{\theta_s} .
Into each term, we insert a random rank-one matrix \nu_s \nu_s^{\top} with expectation \mathbb{E} [\nu_s \nu_s^{\top}] = I (e.g. vectors of random signs):

\mathcal{J}^{h_t}_{\theta} \approx \sum_{s \leqslant t} \mathcal{J}^{h_t}_{h_s} \nu_s \nu_s^{\top} \mathcal{J}^{h_s}_{\theta_s} .
The random projections onto \nu_s serve to compress the matrices \mathcal{J}^{h_t}_{h_s} and \mathcal{J}^{h_s}_{\theta_s} into vector-sized quantities. But accumulating this sum online is still expensive: we must either accumulate the matrix-sized quantities \mathcal{J}^{h_t}_{h_s} \nu_s \nu_s^{\top} \mathcal{J}^{h_s}_{\theta_s} or the sequence of pairs of vectors \mathcal{J}^{h_t}_{h_s} \nu_s, \nu_s^{\top} \mathcal{J}^{h_s}_{\theta_s}.
We can pull the same trick again and rely on noise to entangle corresponding pairs. Let \tau be a random vector with expectation \mathbb{E} [\tau \tau^{\top}] = I. Then

\mathcal{J}^{h_t}_{\theta} \approx \sum_{s \leqslant t} \mathcal{J}^{h_t}_{h_s} \nu_s \nu_s^{\top} \mathcal{J}^{h_s}_{\theta_s} = \sum_{s \leqslant t} \mathcal{J}^{h_t}_{h_s} \nu_s \tau_s \tau_s \nu_s^{\top} \mathcal{J}^{h_s}_{\theta_s} \approx \left( \sum_{s \leqslant t} \mathcal{J}^{h_t}_{h_s} \nu_s \tau_s \right) \left( \sum_{s \leqslant t} \tau_s \nu_s^{\top} \mathcal{J}^{h_s}_{\theta_s} \right) .
For simplicity, we will replace the independent noise variables \tau_t, \nu_t by a single random vector u_t = \tau_t \nu_t.
Now we have two vector-valued sums, which we can efficiently maintain online:
\begin{align} \tilde{h}_t & = \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1} + u_t \\ \tilde{w}_t^{\top} & = \tilde{w}_{t - 1}^{\top} + u_t^{\top} \mathcal{J}^{h_t}_{\theta_t} . \end{align}
This joint recursion is similar to that for \mathcal{J}^{h_t}_{\theta} in RTRL; the approximation \tilde{h}_t \tilde{w}_t^{\top} is used as a rank-one stand-in for \mathcal{J}^{h_t}_{\theta}. Notice how the unwieldy matrix-matrix product in RTRL has been replaced by a cheap matrix-vector product \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1}: UORO is as cheap as BPTT.
*[Figure: Forward-propagation of noise in UORO]*
*[Figure: Accumulation of back-propagated noise in UORO]*
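Here is a minimal UORO sketch on the same toy tanh RNN (again with the arbitrary loss L_t = \frac{1}{2} \| h_t \|^2, and without the variance reduction discussed below); only vector-sized quantities are propagated.

```python
import numpy as np

# A minimal UORO sketch (no variance reduction): the rank-one pair
# (h_tilde, w_tilde) stands in for the RTRL Jacobian dh_t/dtheta.
H, X, T = 4, 3, 6
A = H + X + 1
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(H, A))
xs = rng.normal(size=(T, X))

h = np.zeros(H)
h_tilde = np.zeros(H)                 # forward-propagated noise
w_tilde = np.zeros(H * A)             # accumulated back-projected noise
dW = np.zeros_like(W)
for t in range(T):
    a = np.concatenate([h, xs[t], [1.0]])
    h = np.tanh(W @ a)
    D = 1.0 - h**2
    u = rng.choice([-1.0, 1.0], size=H)       # random signs, E[u u^T] = I
    # h_tilde <- J^{h_t}_{h_{t-1}} h_tilde + u  (cheap matrix-vector product)
    h_tilde = (D[:, None] * W[:, :H]) @ h_tilde + u
    # w_tilde <- w_tilde + u^T dh_t/dtheta_t, where dh_t/dtheta_t
    # = diag(D) (I kron a^T), so the increment is vec(outer(u * D, a)).
    w_tilde = w_tilde + np.outer(u * D, a).ravel()
    # Rank-one gradient estimate for L_t = 0.5 * ||h_t||^2:
    dW += (h @ h_tilde) * w_tilde.reshape(H, A)
```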
So that's the basic workings of UORO: randomly project in state space, then randomly project once again in time. Both of these projections introduce errors in the approximation \mathcal{J}^{h_t}_{\theta} \approx \tilde{h}_t \tilde{w}_t^{\top}: errors due to connecting the wrong elements together in the matrix-matrix product \mathcal{J}^{h_t}_{h_s} \mathcal{J}^{h_s}_{\theta_s} (*spatial* cross-terms), and errors due to connecting the wrong time steps q \neq r together in \mathcal{J}^{h_t}_{h_r} \nu_r \nu_q^{\top} \mathcal{J}^{h_q}_{\theta_q} (*temporal* cross-terms). In expectation, these errors cancel out and the approximation is unbiased.
Variance reduction through iterative rescaling
UORO came with a variance reduction technique that iteratively rescales all quantities involved:

\begin{align} \tilde{h}_t & = \gamma_t \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1} + \beta_t u_t \\ \tilde{w}_t^{\top} & = \gamma_t^{- 1} \tilde{w}_{t - 1}^{\top} + \beta_t^{- 1} u_t^{\top} \mathcal{J}^{h_t}_{\theta_t} . \end{align}
The coefficients \gamma_t, \beta_t serve to reduce the norms of undesired temporal cross-terms (e.g. \gamma_t \beta_t^{- 1} \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1} u_t^{\top} \mathcal{J}^{h_t}_{\theta_t}) while leaving corresponding terms (e.g. \gamma_t \gamma_t^{- 1} \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1} \tilde{w}_{t - 1}^{\top}) unaffected. In practice, the brunt of the work seems to be done by \gamma_t, which distributes, across \tilde{h}_t and \tilde{w}_t, the contraction of \tilde{h}_{t - 1} due to forward propagation through \mathcal{J}^{h_t}_{h_{t - 1}} (aka gradient vanishing).
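As an illustration, here is a norm-balancing choice of the coefficients, written as a drop-in replacement for the two update lines in the UORO sketch above. It is in the spirit of the scheme of Tallec & Ollivier, though the exact formulas in their paper differ in details.

```python
# Inside the UORO loop above: balance the norms of the two summands of
# h_tilde against their counterparts in w_tilde (a sketch; the precise
# coefficients used by UORO are derived in Tallec & Ollivier, 2017).
fwd = (D[:, None] * W[:, :H]) @ h_tilde       # J^{h_t}_{h_{t-1}} h_tilde
imm = np.outer(u * D, a).ravel()              # u^T dh_t/dtheta_t
eps = 1e-12                                   # guards against division by zero
gamma = np.sqrt((np.linalg.norm(w_tilde) + eps) / (np.linalg.norm(fwd) + eps))
beta = np.sqrt((np.linalg.norm(imm) + eps) / (np.linalg.norm(u) + eps))
h_tilde = gamma * fwd + beta * u
w_tilde = w_tilde / gamma + imm / beta
```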
In our paper we argue that this variance reduction scheme, although cheap and very effective, has some room for improvement. UORO's coefficients are chosen to minimize
\mathbb{E}_{\tau_t} \left[ \left\| \tilde{h}_t \tilde{w}_t^{\top} - \left( \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1} \tilde{w}_{t - 1}^{\top} + u_t u_t^{\top} \mathcal{J}^{h_t}_{\theta_t} \right) \right\|_F^2 \right],
i.e. the expected norm of the error in \tilde{h}_t \tilde{w}_t^{\top} as a rank-one approximation of the rank-two matrix \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1} \tilde{w}_{t - 1}^{\top} + u_t u_t^{\top} \mathcal{J}^{h_t}_{\theta_t}. This is a natural quantity to target, but it ignores the bigger picture: downstream, the approximate Jacobians \tilde{h}_t \tilde{w}_t^{\top} \approx \mathcal{J}^{h_t}_{\theta} are used to produce a sequence of gradient estimates \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} \approx \mathcal{J}^{L_t}_{\theta}, which are aggregated by some optimization process into a *total gradient estimate*
\sum_{t \leqslant T} \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} \approx \sum_{t \leqslant T} \mathcal{J}^{L_t}_{\theta} = \mathcal{J}^L_{\theta} .
Notice in particular that each of the terms \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} is based largely on the same random quantities, which produces interactions between consecutive gradient estimates. In our paper we instead seek to minimize the *total variance*
\mathbb{E}_u \left[ \left\| \sum_{t \leqslant T} \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} - \mathcal{J}^L_{\theta} \right\|^2 \right] .
Since consecutive gradient estimates are not independent, the variance of the sum is not simply the sum of the variances.
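A quick Monte-Carlo experiment on the toy model makes both points: the sample mean of the total gradient estimate should approach the true gradient (unbiasedness; compare against the BPTT sketch earlier), while the total variance has to be measured on the summed estimate rather than summed across steps.

```python
import numpy as np

# Monte-Carlo check on the toy tanh RNN (hypothetical setup): average the
# total UORO gradient estimate over many independent noise draws, and
# measure the variance of the *sum* of the per-step estimates.
H, X, T, N = 4, 3, 6, 2000
A = H + X + 1
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(H, A))
xs = rng.normal(size=(T, X))

def uoro_total_grad(seed):
    r = np.random.default_rng(seed)
    h, ht, wt = np.zeros(H), np.zeros(H), np.zeros(H * A)
    g = np.zeros(H * A)
    for t in range(T):
        a = np.concatenate([h, xs[t], [1.0]])
        h = np.tanh(W @ a)
        D = 1.0 - h**2
        u = r.choice([-1.0, 1.0], size=H)
        ht = (D[:, None] * W[:, :H]) @ ht + u
        wt = wt + np.outer(u * D, a).ravel()
        g += (h @ ht) * wt            # per-step estimates share the noise
    return g

samples = np.stack([uoro_total_grad(s) for s in range(N)])
mean = samples.mean(axis=0)           # approaches the BPTT gradient
total_var = ((samples - mean) ** 2).sum(axis=1).mean()
print(total_var)
```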
Theoretical framework
We analyze a generalization of UORO's recursions:
\begin{align} \tilde{h}_t & = \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{h}_{t - 1} + \mathcal{J}^{h_t}_{z_t} Q_t u_t \\ \tilde{w}_t^{\top} & = \tilde{w}_{t - 1}^{\top} + u_t^{\top} Q_t^{- 1} \mathcal{J}^{z_t}_{\theta_t} . \end{align}
The symbolic variable z_t may refer to any cut vertex along the path from \theta_t to h_t. In vanilla UORO, z_t \equiv h_t, so projection occurs in state space. Other choices include projection in parameter space (z_t \equiv \theta_t) and projection in preactivation space (which has convenient structure).
We also replaced the scalar coefficients \gamma_t, \beta_t by matrices Q_t (see the paper for the details). These matrices transform the noise vectors u_t; the section on REINFORCE below reveals an interpretation of these matrices as modifying the covariance of exploration noise.
We define the following shorthands:
b^{(t) \top}_s = \mathcal{J}^{L_t}_{z_s} \quad \text{and} \quad J_s = \mathcal{J}^{z_s}_{\theta_s} .
Now the total gradient estimate,
\sum_{t \leqslant T} \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} = \sum_{t \leqslant T} \mathcal{J}^{L_t}_{h_t} \left( \sum_{s \leqslant t} \mathcal{J}^{h_t}_{z_s} Q_s u_s \right) \left( \sum_{s \leqslant t} u_s^{\top} Q_s^{- 1} \mathcal{J}^{z_s}_{\theta_s} \right) = \sum_{t \leqslant T} \left( \sum_{s \leqslant t} b^{(t) \top}_s Q_s u_s \right) \left( \sum_{s \leqslant t} u_s^{\top} Q_s^{- 1} J_s \right),
can be expressed as
\sum_{t \leqslant T} \begin{pmatrix} b^{(t)}_1 \\ \vdots \\ b^{(t)}_T \end{pmatrix}^{\top} \begin{pmatrix} Q_1 & & \\ & \ddots & \\ & & Q_T \end{pmatrix} \begin{pmatrix} u_1 \\ \vdots \\ u_T \end{pmatrix} \begin{pmatrix} u_1 \\ \vdots \\ u_T \end{pmatrix}^{\top} \begin{pmatrix} Q_1 & & \\ & \ddots & \\ & & Q_T \end{pmatrix}^{- 1} \begin{pmatrix} S^{(t)}_1 & & \\ & \ddots & \\ & & S^{(t)}_T \end{pmatrix} \begin{pmatrix} J_1 \\ \vdots \\ J_T \end{pmatrix},
where S^{(t)}_s = \mathbb{1}_{s \leqslant t} I enforces causality: the estimate \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} at time t does not involve contributions \mathcal{J}^{z_s}_{\theta_s} from future (s > t) parameter applications. This property already holds on the b side, as b^{(t)}_s = \mathbb{1}_{s \leqslant t} b^{(t)}_s.
Giving names to these concatenated quantities, we may write
\sum_{t \leqslant T} \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} = \sum_{t \leqslant T} b^{(t) \top} Q u u^{\top} Q^{- 1} S^{(t)} J .
This is a fairly simple expression, which makes it easy to analyze the behavior of the estimator. We see immediately that the estimator is unbiased, as \mathbb{E}_u [Q u u^{\top} Q^{- 1}] = Q \mathbb{E}_u [u u^{\top}] Q^{- 1} = Q Q^{- 1} = I (Q is assumed to be independent of the noise u). In the paper we also derive the variance; for this blog post it will be enough to note that it is dominated by a product of traces,
V (Q) = \sum_{s \leqslant T} \sum_{t \leqslant T} \operatorname{tr} \left( \sum_{r \leqslant T} b^{(s)}_r b^{(t) \top}_r Q_r Q_r^{\top} \right) \operatorname{tr} \left( \sum_{q \leqslant T} S^{(t)}_q J_q J_q^{\top} S^{(s)}_q (Q_q Q_q^{\top})^{- 1} \right) .
Toward improved variance reduction
Joint optimization of the matrices Q_t turned out to be analytically intractable. In the paper we consider Q_t built from per-step scalar coefficients \alpha_t (playing a role analogous to UORO's \gamma_t) together with a constant matrix Q_0, and even alternately optimizing the \alpha_t and Q_0 is difficult. Still, we made some headway on these problems; of particular interest is the quantity
B = \sum_{s \leqslant T} \sum_{t \leqslant T} \left( \sum_{q = 1}^{\min (s, t)} a_q^{\top} a_q \right) \sum_{r \leqslant T} b^{(s)}_r b^{(t) \top}_r,

which gives the optimal Q_0 as Q_0 = B^{- 1 / 4} when projection occurs in preactivation space (see the paper). The vector a_t is the input to the RNN (i.e. the previous state, the current observation and a bias), which shows up here due to the convenient structure of the backward Jacobian J_t = \mathcal{J}^{z_t}_{\theta_t} = I \otimes a_t^{\top} in the case of preactivation-space projection.
In the optimal case Q_0 = B^{- 1 / 4}, the variance contribution V (Q) above reduces from \operatorname{tr} (B) \operatorname{tr} (I) to \operatorname{tr} (B^{1 / 2})^2, which by Cauchy-Schwarz is an improvement to the extent that the eigenspectrum of B is lopsided rather than flat. We show empirically, in a small-scale controlled setting in which B is known (by backprop), that the optimal (in our framework) \alpha_t and Q_0 result in significant variance reduction.
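The claimed reduction is easy to check numerically on a random positive-definite stand-in for B:

```python
import numpy as np

# Checking tr(B^{1/2})^2 <= tr(B) tr(I) for a random PSD stand-in for B,
# and that Q_0 = B^{-1/4} attains the former in V(Q_0) = tr(B Q_0 Q_0^T)
# tr((Q_0 Q_0^T)^{-1}).
rng = np.random.default_rng(0)
n = 50
M = rng.normal(size=(n, n))
B = M @ M.T                               # random PSD matrix
evals, evecs = np.linalg.eigh(B)
Q0 = evecs @ np.diag(evals ** -0.25) @ evecs.T   # Q_0 = B^{-1/4}
QQ = Q0 @ Q0.T                            # = B^{-1/2}
V_opt = np.trace(B @ QQ) * np.trace(np.linalg.inv(QQ))
baseline = np.trace(B) * n                # tr(B) tr(I), the Q_0 = I case
optimal = np.sum(np.sqrt(evals)) ** 2     # tr(B^{1/2})^2
print(np.isclose(V_opt, optimal), optimal <= baseline)
```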
Of course, it is not obvious how to implement these ideas effectively. The theoretically optimal choices for these quantities depend on information that is unknown. For instance, the matrix B is sort of a weighted covariance of sums of future gradients \sum_{s = q}^T b^{(s)}_r; if we knew these gradients we wouldn't need to maintain the RTRL Jacobian \mathcal{J}^{h_t}_{\theta} in the first place! One cool idea is to estimate it online using the same rank-one tricks we use to estimate the gradient, which results in a self-improving algorithm. We show one such estimator, but it (and others like it) fared poorly in practice, as one might have guessed.
However, the theory is meant to guide practice, not to dictate it. Unbiased estimation of B is not necessary (the overall algorithm is unbiased for any Q_0), nor is it particularly desirable if it comes at the cost of injecting more noise into the system. It is well known from optimization that second-order information is hard to estimate reliably in the stochastic setting. Most likely there exist heuristic choices of Q_0 (guided by the theory) that enable variance reduction without introducing additional noise variables, and which may be more amenable to noisy estimation in the first place.
Projecting in preactivation space
Suppose the transition function takes the usual form

\begin{align} p_t & = W_t a_t \\ h_t & = f (p_t), \end{align}

such that \theta_t = \operatorname{vec} (W_t) and with a_t = \left(\begin{array}{ccc} h_{t - 1}^{\top} & x_t^{\top} & 1 \end{array}\right)^{\top} being the concatenated input to the RNN at time t. Then the gradient with respect to W_t is just the outer product of the gradient with respect to the preactivations p_t and the inputs a_t:
\nabla_{W_t} L = (\mathcal{J}^L_{p_t})^\top a_t^\top
The notation becomes a bit of a trainwreck when we switch to working with the vectorization \theta_t = \operatorname{vec}(W_t) in order to speak of the Jacobians \mathcal{J}^{h_t}_{\theta_t}. Let's switch to Kronecker products for \otimes' sake:
\mathcal{J}^{h_t}_{\theta_t} = \mathcal{J}^{h_t}_{p_t} (I \otimes a_t^{\top})
This is formally a matrix, but it's more naturally thought of as a third-order H \times P \times A tensor with elements \frac{\partial h_{t i}}{\partial p_{t j}} a_{t k}. (H, P, A are the dimensions of h_t, p_t and a_t respectively.)
The key observation is that this third-order tensor can be broken up into an outer product of an H \times P matrix \mathcal{J}^{h_t}_{p_t} and an A-dimensional vector a_t without any random projection. Plain UORO, on the other hand, would stochastically break it up into an outer product of an H-dimensional vector \mathcal{J}^{h_t}_{p_t} u_t and a vectorized P \times A matrix u_t^\top \mathcal{J}^{p_t}_{\theta_t} = u_t^\top (I \otimes a_t^\top) = \operatorname{vec}(u_t a_t^\top). Although the factorization without projection does not introduce extra variance, it does introduce extra computation, as we are now propagating a matrix again as in RTRL (albeit a smaller one).
Now the approximate Jacobian takes the form \tilde{H}_t (I \otimes \tilde{a}_t^{\top}) \approx \mathcal{J}^{h_t}_{\theta}, with \tilde{H}_t and \tilde{a}_t maintained according to
\begin{align} \tilde{H}_t & = \gamma_t \mathcal{J}^{h_t}_{h_{t - 1}} \tilde{H}_{t - 1} + \beta_t \tau_t \mathcal{J}^{h_t}_{p_t} \\ \tilde{a}_t & = \gamma_t^{- 1} \tilde{a}_{t - 1} + \beta_t^{- 1} \tau_t a_t . \end{align}
These are similar to UORO's recursions, except that the Jacobian factors \mathcal{J}^{h_t}_{p_t} and I \otimes a_t^\top are not projected onto noise vectors u_t but rather multiplied by random signs \tau_t before being accumulated into their respective sums.
Each gradient \mathcal{J}^{L_t}_{\theta} is estimated by \mathcal{J}^{L_t}_{h_t} \tilde{H}_t (I \otimes \tilde{a}_t^{\top}) = \operatorname{vec} ( (\mathcal{J}^{L_t}_{h_t} \tilde{H}_t)^\top \tilde{a}_t^\top), which can still be computed without explicitly forming \tilde{H}_t (I \otimes \tilde{a}_t^{\top}).
*[Figure: Forward-propagation of noise in preactivation space]*
*[Figure: Accumulation of activations]*
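Here is a minimal sketch of this variant on the toy tanh RNN (with the coefficients \gamma_t, \beta_t omitted for clarity). Note that p_t has the same size as h_t here, so \tilde{H}_t is H \times H.

```python
import numpy as np

# Preactivation-space variant on the toy tanh RNN: propagate a small matrix
# H_tilde (H x H, since P = H here) and a vector a_tilde, entangled in time
# by scalar random signs tau_t. Rescaling coefficients are omitted.
H, X, T = 4, 3, 6
A = H + X + 1
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(H, A))
xs = rng.normal(size=(T, X))

h = np.zeros(H)
H_tilde = np.zeros((H, H))            # accumulates J^{h_t}_{p_s} factors
a_tilde = np.zeros(A)                 # accumulates sign-flipped inputs a_s
dW = np.zeros_like(W)
for t in range(T):
    a = np.concatenate([h, xs[t], [1.0]])
    h = np.tanh(W @ a)
    D = 1.0 - h**2
    tau = rng.choice([-1.0, 1.0])     # one random sign per time step
    # H_tilde <- J^{h_t}_{h_{t-1}} H_tilde + tau * J^{h_t}_{p_t}
    H_tilde = (D[:, None] * W[:, :H]) @ H_tilde + tau * np.diag(D)
    a_tilde = a_tilde + tau * a
    # Gradient estimate vec((J^{L_t}_{h_t} H_tilde)^T a_tilde^T),
    # for the arbitrary loss L_t = 0.5 * ||h_t||^2:
    dW += np.outer(h @ H_tilde, a_tilde)
```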
The group that discovered this algorithm around the same time (in the aforementioned Approximating Real-Time Recurrent Learning with Random Kronecker Factors, Mujika et al., 2018) released a new paper a few days ago, Optimal Kronecker-Sum Approximation of Real Time Recurrent Learning (Benzing et al., 2019). I haven't had time to read it in full, but it looks like they found a way to optimize the low-rank approximation without biasing it.
A link to REINFORCE
Consider now a perturbed version of the RNN, in which we inject noise into the state at each step:

\begin{align} h_t & = F (\bar{h}_{t - 1}, x_t ; \theta_t) \\ \bar{h}_t & = h_t + \sigma Q_t u_t . \end{align}
Here u_t \sim \mathcal{N} (0, I) is additive Gaussian noise, and \sigma determines the level of noise. The invertible matrix Q_t transforms the standard normal noise u_t and corresponds to a covariance matrix; it is included here because it will end up playing the same role as the Q_t matrices discussed in the variance reduction section above. Effectively, the stochastic hidden state \bar{h}_t \sim \mathcal{N} (h_t, \sigma^2 Q_t Q_t^{\top}) is sampled from a Gaussian distribution centered on the deterministic hidden state h_t. We assume the loss L_t to be a differentiable function of \bar{h}_t.
*[Figure: RNN with state perturbation]*
To be clear, the sole purpose of injecting this noise is so that we may apply REINFORCE to estimate gradients through the system and compare these estimates to those of UORO. The stochastic transition distribution will be our policy from the REINFORCE perspective, which suggests actions \bar{h}_t given states h_t. We compute the REINFORCE estimator by running the stochastic RNN forward, thus sampling a trajectory of states \bar{h}_t, and at each step computing
L_t \nabla_{\theta} \log p (\bar{h}_t | \bar{h}_0, \bar{h}_1 \ldots \bar{h}_{t - 1} ; \theta) \approx \mathcal{J}^{L_t}_{\theta} .
The estimate consists of the loss L_t times the score function of the trajectory. Intuitively, higher rewards (equivalently, lower losses) "reinforce" directions in parameter space that bring them about.
The score function \bar{w}_t^{\top} = \nabla_{\theta} \log p (\bar{h}_t | \bar{h}_0, \bar{h}_1 \ldots \bar{h}_{t - 1} ; \theta) is maintained online according to
\bar{w}_t^{\top} = \bar{w}_{t - 1}^{\top} + \nabla_{\theta} \log p (\bar{h}_t | \bar{h}_{t - 1} ; \theta_t) = \bar{w}_{t - 1}^{\top} + \frac{1}{\sigma} u_t^{\top} Q_t^{- 1} \mathcal{J}^{h_t}_{\theta_t},
which is analogous to \tilde{w} in UORO. An important difference is that the backward Jacobians \mathcal{J}^{h_t}_{\theta_t} are evaluated in the noisy system. In the paper we eliminate this difference by passing to the limit \sigma \rightarrow 0, which simulates the common practice of annealing the noise.
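Here is a minimal sketch of this REINFORCE estimator on the toy tanh RNN, with Q_t = I for simplicity and the same arbitrary per-step loss as before, now evaluated on the noisy state:

```python
import numpy as np

# REINFORCE on the perturbed toy RNN (Q_t = I): run the noisy system
# forward, accumulate the score w_bar of the trajectory, and weight it
# by the per-step loss.
H, X, T = 4, 3, 6
A = H + X + 1
sigma = 0.1
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(H, A))
xs = rng.normal(size=(T, X))

h_bar = np.zeros(H)
w_bar = np.zeros(H * A)               # score of the trajectory so far
dW = np.zeros_like(W)
for t in range(T):
    a = np.concatenate([h_bar, xs[t], [1.0]])
    h = np.tanh(W @ a)                # deterministic part of the transition
    u = rng.normal(size=H)
    h_bar = h + sigma * u             # perturbed state
    # grad_theta log p(h_bar_t | h_bar_{t-1}) = (1/sigma) u^T dh_t/dtheta_t,
    # with the Jacobian diag(D)(I kron a^T) evaluated in the noisy system:
    D = 1.0 - h**2
    w_bar = w_bar + np.outer(u * D, a).ravel() / sigma
    L_t = 0.5 * np.sum(h_bar**2)      # per-step loss on the noisy state
    dW += L_t * w_bar.reshape(H, A)   # REINFORCE gradient estimate
```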
Besides \tilde{w}_t, UORO's estimate \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top} \approx \mathcal{J}^{L_t}_{\theta} involves the quantity \mathcal{J}^{L_t}_{h_t} \tilde{h}_t. In REINFORCE, the inner product of \bar{w}_t with this quantity is implicit in the multiplication by the loss. We can reveal it by taking the Taylor series of the loss around the point u = 0 where the noise is zero:
L_t = L_t |_{u = 0} + \left( \sum_{s \leqslant t} \mathcal{J}^{L_t}_{u_s} |_{u = 0} u_s \right) + \frac{1}{2} \left( \sum_{r \leqslant t} \sum_{s \leqslant t} u_r^{\top} \mathcal{H}^{L_t}_{u_r, u_s} |_{u = 0} u_s \right) + \cdots
Using the fact that derivatives with respect to u_s are directly related to derivatives with respect to h_s, namely
\mathcal{J}^{L_t}_{u_s} |_{u = 0} = \sigma \mathcal{J}^{L_t}_{h_s} |_{u = 0} Q_s,
we may write
L_t = L_t |_{u = 0} + \sigma \mathcal{J}^{L_t}_{h_t} \left( \sum_{s \leqslant t} \mathcal{J}^{h_t}_{h_s} |_{u = 0} Q_s u_s \right) +\mathcal{O} (\sigma^2) .
Plugging these into the REINFORCE estimate L_t \nabla_{\theta} \log p (\bar{h}_t | \bar{h}_0, \bar{h}_1 \ldots \bar{h}_{t - 1} ; \theta), we get
L_t \bar{w}_t^{\top} = \frac{1}{\sigma} L_t |_{u = 0} \left( \sum_{s \leqslant t} u_s^{\top} Q_s^{- 1} \mathcal{J}^{h_s}_{\theta_s} \right) +\mathcal{J}^{L_t}_{h_t} \left( \sum_{s \leqslant t} \mathcal{J}^{h_t}_{h_s} |_{u = 0} Q_s u_s \right) \left( \sum_{s \leqslant t} u_s^{\top} Q_s^{- 1} \mathcal{J}^{h_s}_{\theta_s} \right) +\mathcal{O} (\sigma) .
When we pass to the limit \sigma \rightarrow 0, the second term becomes identical to the UORO estimate \mathcal{J}^{L_t}_{h_t} \tilde{h}_t \tilde{w}_t^{\top}. Note how the Q_t matrices that determine the covariance of the exploration noise in REINFORCE play the exact same role as our variance reduction matrices in UORO.
The first term, which is zero in expectation, contributes infinite variance in the limit. In effect, annealing the noise deteriorates the quality of REINFORCE's estimates. This first term is usually addressed by subtracting a "baseline" -- an estimate of L_t |_{u = 0} -- from the loss L_t before multiplying with the score.
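The blow-up is easy to see in isolation: with stand-in values for L_t |_{u = 0} and the Jacobians (hypothetical numbers, for illustration only), the variance of the first term grows like 1 / \sigma^2.

```python
import numpy as np

# The term (1/sigma) L|_{u=0} sum_s u_s^T Q_s^{-1} J_s has mean zero but
# variance proportional to 1/sigma^2. Toy stand-ins for L|_{u=0} and J:
rng = np.random.default_rng(0)
L0 = 1.7                                   # stand-in for L_t at u = 0
J = rng.normal(size=(5, 8))                # stand-in for a backward Jacobian
for sigma in [1.0, 0.1, 0.01]:
    terms = np.array([L0 / sigma * (rng.normal(size=5) @ J)
                      for _ in range(10000)])
    print(sigma, terms.var(axis=0).sum())  # grows like 1/sigma^2
```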
Conclusions
We've delved deeply into UORO, contributing a straightforward derivation, a general theoretical framework and a thorough analysis. Our proposed variance reduction using these Q_t matrices is promising, although much work remains to be done. We've shown a variant of UORO that avoids the spatial level of stochastic approximation, thereby greatly reducing the variance at the cost of equally greatly increasing the time complexity. Finally, we have established a deep link between UORO and REINFORCE, which allows the interpretation of REINFORCE as an approximation to RTRL.