2025-02-03
To understand how backpropagation through time (BPTT) works, let’s derive its mathematical formulation.
Consider an RNN that processes a sequence of inputs \(x_1, x_2, \ldots, x_T\). At each time step \(t\), the RNN maintains a hidden state \(h_t\), which is updated based on the current input \(x_t\) and the previous hidden state \(h_{t-1}\):
\[h_t = f(W_h h_{t-1} + W_x x_t + b)\]
where \(W_h\) and \(W_x\) are weight matrices, \(b\) is a bias vector, and \(f\) is an activation function (typically tanh or ReLU).
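As a concrete illustration, here is a minimal sketch of one update step in NumPy, assuming \(f = \tanh\); the dimensions and the small random initialization are arbitrary choices for the example, not prescribed by the derivation:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, input_size = 4, 3

# Illustrative parameters; the initialization scheme is arbitrary.
W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
W_x = rng.normal(scale=0.1, size=(hidden_size, input_size))
b = np.zeros(hidden_size)

def step(h_prev, x_t):
    """One recurrence step: h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)."""
    return np.tanh(W_h @ h_prev + W_x @ x_t + b)
```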
Assume we have a loss function \(L\) that depends on the outputs of the RNN at each time step. The total loss over the sequence is:
\[L = \sum_{t=1}^T L_t(y_t, \hat{y}_t)\]
where \(y_t\) is the true output and \(\hat{y}_t\) is the predicted output at time step \(t\).
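The map from \(h_t\) to \(\hat{y}_t\) is left unspecified above; a common choice, adopted here purely for concreteness (with a hypothetical readout matrix \(W_y\) and bias \(c\)), is a linear readout with a squared-error loss:

\[\hat{y}_t = W_y h_t + c, \qquad L_t(y_t, \hat{y}_t) = \tfrac{1}{2} \lVert y_t - \hat{y}_t \rVert^2\]

The code sketches below use this choice; any differentiable readout and per-step loss would work the same way.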
To train the RNN, we need to compute the gradients of the loss with respect to the weights \(W_h\) and \(W_x\). BPTT involves unfolding the RNN through time and applying the chain rule of calculus to compute these gradients.
Forward Pass: Compute the hidden states \(h_t\) and the outputs \(\hat{y}_t\) for \(t = 1, 2, \ldots, T\) (a code sketch follows these two steps).
Backward Pass: Compute the gradients of the loss with respect to the hidden states and weights by propagating the error backwards through time.
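Under the linear-readout and squared-error assumptions above, the forward pass might look like this sketch, continuing the earlier NumPy setup (`W_y` and `c` are the hypothetical readout parameters):

```python
import numpy as np

def forward(xs, h0, W_h, W_x, b, W_y, c):
    """Run the RNN over a sequence, caching every state for the backward pass."""
    hs, y_hats = [h0], []
    for x_t in xs:
        h_t = np.tanh(W_h @ hs[-1] + W_x @ x_t + b)  # recurrence from above
        hs.append(h_t)
        y_hats.append(W_y @ h_t + c)                  # assumed linear readout
    return hs, y_hats

def total_loss(ys, y_hats):
    """L = sum_t 0.5 * ||y_t - y_hat_t||^2, the assumed per-step loss."""
    return sum(0.5 * np.sum((y - yh) ** 2) for y, yh in zip(ys, y_hats))
```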
The gradient of the loss with respect to the hidden state at time step \(t\) is:
\[\frac{\partial L}{\partial h_t} = \sum_{k=t}^T \frac{\partial L_k}{\partial h_t}\]
Using the chain rule, we can express this as:
\[\frac{\partial L_k}{\partial h_t} = \frac{\partial L_k}{\partial \hat{y}_k} \frac{\partial \hat{y}_k}{\partial h_k} \frac{\partial h_k}{\partial h_t}\]
The Jacobian of the hidden state \(h_k\) with respect to \(h_t\) traces the recurrent connection back through every intermediate step, and factors into a product of per-step Jacobians:
\[\frac{\partial h_k}{\partial h_t} = \prod_{j=t+1}^k \frac{\partial h_j}{\partial h_{j-1}}\]
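For concreteness, with \(f = \tanh\) each factor in this product has a closed form, since \(\tanh'(a) = 1 - \tanh^2(a)\):

\[\frac{\partial h_j}{\partial h_{j-1}} = \operatorname{diag}\!\left(1 - h_j \odot h_j\right) W_h\]

where \(\odot\) denotes elementwise multiplication.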
Finally, the gradients of the loss with respect to the weights are computed by summing the contributions from each time step, where \(\frac{\partial h_t}{\partial W_h}\) and \(\frac{\partial h_t}{\partial W_x}\) denote the immediate partial derivatives at step \(t\) (holding \(h_{t-1}\) fixed, since the dependence through earlier steps is already captured in \(\frac{\partial L}{\partial h_t}\)):
\[\frac{\partial L}{\partial W_h} = \sum_{t=1}^T \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W_h}\]
\[\frac{\partial L}{\partial W_x} = \sum_{t=1}^T \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W_x}\]
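Putting the pieces together, here is a minimal BPTT sketch under the same assumptions (tanh activation, linear readout, squared-error loss; `forward` and the parameter names come from the earlier snippets):

```python
def bptt(xs, ys, h0, W_h, W_x, b, W_y, c):
    """Accumulate the weight gradients by walking the unrolled graph backwards."""
    hs, y_hats = forward(xs, h0, W_h, W_x, b, W_y, c)
    dW_h, dW_x, db = np.zeros_like(W_h), np.zeros_like(W_x), np.zeros_like(b)
    dW_y, dc = np.zeros_like(W_y), np.zeros_like(c)
    dh_next = np.zeros_like(h0)  # dL/dh_t arriving from steps k > t

    for t in reversed(range(len(xs))):
        h_t, h_prev = hs[t + 1], hs[t]
        dy = y_hats[t] - ys[t]            # dL_t/dy_hat_t for squared error
        dW_y += np.outer(dy, h_t)
        dc += dy
        dh = W_y.T @ dy + dh_next         # dL/dh_t: local plus future terms
        da = (1.0 - h_t ** 2) * dh        # through tanh: diag(1 - h_t^2)
        dW_h += np.outer(da, h_prev)      # step t's contribution to dL/dW_h
        dW_x += np.outer(da, xs[t])
        db += da
        dh_next = W_h.T @ da              # propagate to h_{t-1} via W_h

    return dW_h, dW_x, db, dW_y, dc
```

A finite-difference check against `total_loss` (perturb one weight, re-run the forward pass, and compare the loss change to the computed gradient) is a quick way to validate a sketch like this.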