
6.2 Training the Critic

In the previous section, we defined the advantage function $A(s,a) \approx \delta = r + \gamma V(s') - V(s)$ and introduced the Critic network as an estimator of $V(s)$. This section revisits the three classic value-estimation methods from Chapter 3 (DP, MC, and TD) and shows how each one can be used to train the Critic.

Prerequisites

We continue with the three-cell corridor from Chapter 3, using a fixed policy $\pi$: at both $S$ and $M$, move right with probability 0.8 and left with probability 0.2. The transitions and rewards are:

| Current state | Action | Policy prob | Next state | Reward |
|---|---|---|---|---|
| $S$ | left | 0.2 | $S$ | $-2$ |
| $S$ | right | 0.8 | $M$ | $-1$ |
| $M$ | left | 0.2 | $S$ | $-2$ |
| $M$ | right | 0.8 | $G$ | $-1$ |
| $G$ | end | 1.0 | (none) | $0$ |

We set $\gamma = 1$. All three methods estimate the same value table; they differ only in where the update targets come from.

DP: A Theoretical Baseline

If we knew the full transition probabilities $P$ and reward function $R$ (recall the MDP 5-tuple), we could iterate the Critic directly using the Bellman expectation equation:

$$V_\phi(s) \leftarrow \sum_a \pi(a \mid s) \left[ R(s,a) + \gamma \sum_{s'} P(s' \mid s,a) V_\phi(s') \right]$$

Each symbol in this equation:

| Symbol | Meaning |
|---|---|
| $V_\phi(s)$ | Critic's current value estimate for state $s$, with parameters $\phi$ |
| $a$ | An action available at state $s$ (e.g., left, right) |
| $\pi(a \mid s)$ | Probability that the current policy selects action $a$ at state $s$ |
| $R(s,a)$ | Immediate reward received after taking action $a$ in state $s$ |
| $s'$ | A possible next state after taking action $a$ |
| $P(s' \mid s,a)$ | Probability of transitioning to $s'$ from state $s$ under action $a$ |
| $V_\phi(s')$ | Critic's current value estimate for the next state $s'$ |
| $\gamma$ | Discount factor |

Now expand the update for state $S$. The outer sum weights the actions by the policy; the inner sum weights the next states by the transition probabilities. Because every action here leads to exactly one next state (moving right advances one cell; moving left either bumps into the wall at $S$ or moves back from $M$ to $S$), the inner $\sum_{s'}$ collapses to that single destination, which has probability 1. With $\gamma = 1$, the discount factor also drops out:

$$\begin{aligned} V_\phi(S) &\leftarrow \pi(\text{right} \mid S)\left[R(S,\text{right}) + V_\phi(M)\right] + \pi(\text{left} \mid S)\left[R(S,\text{left}) + V_\phi(S)\right] \\ &= 0.8\left[-1 + V_\phi(M)\right] + 0.2\left[-2 + V_\phi(S)\right] \end{aligned}$$

Similarly for $M$, where moving right reaches the terminal state $G$ and moving left returns to $S$:

$$\begin{aligned} V_\phi(M) &\leftarrow \pi(\text{right} \mid M)\left[R(M,\text{right}) + V_\phi(G)\right] + \pi(\text{left} \mid M)\left[R(M,\text{left}) + V_\phi(S)\right] \\ &= 0.8\left[-1 + V_\phi(G)\right] + 0.2\left[-2 + V_\phi(S)\right] \end{aligned}$$

Applying this update repeatedly to all states makes $V_\phi$ converge to the exact $V^\pi$. Starting from an all-zero table, we substitute numbers round by round; each round computes every new value from the previous round's table.

Round 1: the old table is all zeros, so each target reduces to the policy-weighted average of the immediate rewards:

$$\begin{aligned} V_1(S) &= 0.8[-1 + V_0(M)] + 0.2[-2 + V_0(S)] \\ &= 0.8(-1 + 0) + 0.2(-2 + 0) = -0.8 - 0.4 = -1.2 \end{aligned}$$

$$\begin{aligned} V_1(M) &= 0.8[-1 + V_0(G)] + 0.2[-2 + V_0(S)] \\ &= 0.8(-1 + 0) + 0.2(-2 + 0) = -0.8 - 0.4 = -1.2 \end{aligned}$$

Round 2: using the round-1 results as the old table:

$$\begin{aligned} V_2(S) &= 0.8[-1 + V_1(M)] + 0.2[-2 + V_1(S)] \\ &= 0.8[-1 + (-1.2)] + 0.2[-2 + (-1.2)] \\ &= 0.8 \times (-2.2) + 0.2 \times (-3.2) = -1.76 - 0.64 = -2.4 \end{aligned}$$

$$\begin{aligned} V_2(M) &= 0.8[-1 + V_1(G)] + 0.2[-2 + V_1(S)] \\ &= 0.8[-1 + 0] + 0.2[-2 + (-1.2)] \\ &= 0.8 \times (-1) + 0.2 \times (-3.2) = -0.8 - 0.64 = -1.44 \end{aligned}$$

Round 3: using the round-2 results as the old table:

$$\begin{aligned} V_3(S) &= 0.8[-1 + V_2(M)] + 0.2[-2 + V_2(S)] \\ &= 0.8[-1 + (-1.44)] + 0.2[-2 + (-2.4)] \\ &= 0.8 \times (-2.44) + 0.2 \times (-4.4) = -1.952 - 0.88 = -2.832 \end{aligned}$$

$$\begin{aligned} V_3(M) &= 0.8[-1 + V_2(G)] + 0.2[-2 + V_2(S)] \\ &= 0.8(-1 + 0) + 0.2[-2 + (-2.4)] \\ &= -0.8 + 0.2 \times (-4.4) = -0.8 - 0.88 = -1.68 \end{aligned}$$

Summary of each round:

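| Round | $V(S)$ | $V(M)$ | $V(G)$ |
|---|---|---|---|
| 0 | $0$ | $0$ | $0$ |
| 1 | $-1.2$ | $-1.2$ | $0$ |
| 2 | $-2.4$ | $-1.44$ | $0$ |
| 3 | $-2.832$ | $-1.68$ | $0$ |
| converged | $-3.375$ | $-1.875$ | $0$ |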

In every round, the values of $S$ and $M$ encode the "average consequence of acting under the current policy": moving right is usually the better move, but the policy occasionally moves left, so the costs of detours and wall bumps must also enter the value table.

With these values we can also perform policy improvement: at each state $s$, choose the action that maximizes $Q(s,a) = R(s,a) + \gamma \sum_{s'} P(s' \mid s,a) V(s')$ (recall the greedy optimal policy). The loop "evaluate the policy $\rightarrow$ improve the policy $\rightarrow$ evaluate again" is exactly Policy Iteration, which, for a finite MDP, is guaranteed to converge to an optimal policy.
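As a sketch of one improvement step, the snippet below computes $Q(s,a)$ from the converged values and picks the greedy action in each state. The `STEP` table is a hypothetical encoding of the corridor's deterministic transitions, so the sum over $s'$ reduces to a single term.

```python
# Minimal sketch: one greedy policy-improvement step from the converged V^pi.
V = {"S": -3.375, "M": -1.875, "G": 0.0}
GAMMA = 1.0
# Hypothetical encoding of the deterministic corridor: (state, action) -> (next_state, reward)
STEP = {
    ("S", "left"): ("S", -2.0),
    ("S", "right"): ("M", -1.0),
    ("M", "left"): ("S", -2.0),
    ("M", "right"): ("G", -1.0),
}


def q_value(s, a):
    """Q(s,a) = R(s,a) + gamma * V(s') for a deterministic transition."""
    s_next, r = STEP[(s, a)]
    return r + GAMMA * V[s_next]


for s in ("S", "M"):
    print(s, {a: q_value(s, a) for a in ("left", "right")})

# Greedy improvement: pick argmax_a Q(s, a) in every non-terminal state.
greedy = {s: max(("left", "right"), key=lambda a: q_value(s, a)) for s in ("S", "M")}
print(greedy)
```

Moving right wins in both states ($Q(S,\text{right}) = -2.875$ vs. $Q(S,\text{left}) = -5.375$; $Q(M,\text{right}) = -1$ vs. $Q(M,\text{left}) = -5.375$), so the improved policy always moves right; Policy Iteration would then evaluate that new policy and repeat.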

In real-world problems, however, the complete $P$ and $R$ are almost never known. In Actor-Critic, DP therefore serves mainly as a theoretical baseline: it tells you what the Critic's answer should be when everything is known.

MC: Update the Critic Using Complete Trajectories

Monte Carlo (MC) updates wait until a complete episode finishes, then use the actual return $G_t$ to train the Critic. The Critic loss is a mean squared error:

$$L_{\text{Critic}} = \left( G_t - V_\phi(s) \right)^2 \tag{6.3}$$

Each symbol in this equation:

| Symbol | Meaning |
|---|---|
| $L_{\text{Critic}}$ | Critic loss, measuring the prediction error |
| $G_t$ | Actual discounted return from time step $t$ to the end of the episode (the MC target) |
| $V_\phi(s)$ | Critic's current value prediction for state $s$ |

$G_t - V_\phi(s)$ is the Critic's prediction error: the episode actually returned $G_t$, but the Critic had predicted $V_\phi(s)$. The loss is the square of this error.

Numerical Example

Suppose we sample the following trajectory:

$$S \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-1} G$$

With $\gamma = 1$, we compute $G_t$ for each visit by summing the rewards from that position to the end of the episode:

| Visit | State | Remaining rewards | $G_t$ computation | MC target $G_t$ |
|---|---|---|---|---|
| 1 | $S$ | $-2, -1, -2, -1, -1$ | $-2 + (-1) + (-2) + (-1) + (-1)$ | $-7$ |
| 2 | $S$ | $-1, -2, -1, -1$ | $-1 + (-2) + (-1) + (-1)$ | $-5$ |
| 3 | $M$ | $-2, -1, -1$ | $-2 + (-1) + (-1)$ | $-4$ |
| 4 | $S$ | $-1, -1$ | $-1 + (-1)$ | $-2$ |
| 5 | $M$ | $-1$ | $-1$ | $-1$ |
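In code, the returns are usually computed in one backward pass with the recursion $G_t = r_{t+1} + \gamma G_{t+1}$ instead of re-summing the tail at every visit. A minimal sketch, assuming the trajectory is stored as parallel lists of states and of the rewards received on leaving them (`states`, `rewards`, and `discounted_returns` are illustrative names):

```python
# Minimal sketch: compute every-visit MC targets G_t with one backward pass.
def discounted_returns(rewards, gamma=1.0):
    """Return [G_0, G_1, ...] where G_t = rewards[t] + gamma * G_{t+1}."""
    returns = [0.0] * len(rewards)
    g = 0.0  # the return after the terminal state is 0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


states = ["S", "S", "M", "S", "M"]   # state at each visit
rewards = [-2, -1, -2, -1, -1]       # reward received when leaving that state
for s, g in zip(states, discounted_returns(rewards)):
    print(s, g)
```

The printed targets are $-7, -5, -4, -2, -1$, matching the table.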

Loss Computation and Gradient Update

Assume the Critic is a simple value table with $V(S) = 0$ and $V(M) = 0$. Taking the first visit to $S$ as an example, the MC target is $G_t = -7$:

$$L = (G_t - V(S))^2 = (-7 - 0)^2 = 49$$

The gradient-descent update (learning rate $\alpha = 0.5$):

$$V(S) \leftarrow V(S) - \alpha \cdot \frac{\partial L}{\partial V(S)} = V(S) - \alpha \cdot 2(V(S) - G_t)$$

Here $\frac{\partial L}{\partial V(S)} = 2(V(S) - G_t) = 2(0 - (-7)) = 14$, so the literal gradient step would give $V(S) = 0 - 0.5 \times 14 = -7$. It is more common to absorb the factor of 2 into the learning rate (equivalently, to define the loss as $\frac{1}{2}(G_t - V(S))^2$) and write directly:

$$V(S) \leftarrow V(S) + \alpha (G_t - V(S)) = 0 + 0.5 \times (-7 - 0) = -3.5$$

The complete update process across all visits:

| Updated state | MC target $G_t$ | Old value | Update computation | New value |
|---|---|---|---|---|
| 1st $S$ | $-7$ | $0$ | $0 + 0.5 \times (-7 - 0) = -3.5$ | $-3.5$ |
| 2nd $S$ | $-5$ | $-3.5$ | $-3.5 + 0.5 \times [-5 - (-3.5)] = -3.5 + 0.5 \times (-1.5) = -4.25$ | $-4.25$ |
| 1st $M$ | $-4$ | $0$ | $0 + 0.5 \times (-4 - 0) = -2$ | $-2$ |
| 3rd $S$ | $-2$ | $-4.25$ | $-4.25 + 0.5 \times [-2 - (-4.25)] = -4.25 + 0.5 \times 2.25 = -3.125$ | $-3.125$ |
| 2nd $M$ | $-1$ | $-2$ | $-2 + 0.5 \times [-1 - (-2)] = -2 + 0.5 \times 1 = -1.5$ | $-1.5$ |
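The table can be reproduced by a sketch of every-visit, constant-$\alpha$ MC that waits for the episode to finish and then applies the updates in visit order. `mc_update` is a hypothetical helper name, and the value table is a plain dictionary:

```python
# Minimal sketch: every-visit constant-alpha MC, applied after the episode ends.
def mc_update(values, states, rewards, alpha=0.5, gamma=1.0):
    """Update values in place; return (state, G_t, old, new) for each visit."""
    # Compute all MC targets first: G_t = r_t + gamma * G_{t+1}, backward.
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()

    # Then apply V(s) <- V(s) + alpha * (G_t - V(s)) at each visit, in order.
    log = []
    for s, g_t in zip(states, returns):
        old = values[s]
        values[s] = old + alpha * (g_t - old)
        log.append((s, g_t, old, values[s]))
    return log


V = {"S": 0.0, "M": 0.0, "G": 0.0}
for row in mc_update(V, ["S", "S", "M", "S", "M"], [-2, -1, -2, -1, -1]):
    print(row)
```

The final table is $V(S) = -3.125$, $V(M) = -1.5$, as in the last rows above.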

MC methods (recall the MC value update $V(s) \leftarrow V(s) + \alpha[G_t - V(s)]$) use an unbiased target, because $G_t$ is a sample of the true return, but they have two limitations:

  1. You must wait until the episode ends to compute $G_t$; you cannot learn online step by step.
  2. High variance: $G_t$ can fluctuate drastically from one episode to the next, because it accumulates the randomness of every remaining step.

In a neural-network implementation, the MC method amounts to: run one full episode, collect all $(s_t, G_t)$ pairs, then perform a gradient-descent update on the Critic parameters $\phi$ using this batch.
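As a sketch of that batch update, the snippet below swaps the neural network for a linear Critic $V_\phi(s) = \phi^\top x(s)$ with one-hot features (an assumption made to keep the example dependency-free) and takes gradient steps on the mean squared error over the $(s_t, G_t)$ pairs of the example episode. Iterating on this single batch only shows where the regression is heading; in practice one would take one or a few steps per batch and then collect a new episode.

```python
# Linear Critic with one-hot features standing in for a neural network (assumption).
FEATURES = {"S": [1.0, 0.0], "M": [0.0, 1.0]}


def v_phi(phi, s):
    """V_phi(s) = phi . x(s)."""
    return sum(w * x for w, x in zip(phi, FEATURES[s]))


def batch_mc_step(phi, batch, lr):
    """One gradient-descent step on the batch mean of (G_t - V_phi(s_t))^2."""
    grad = [0.0] * len(phi)
    for s, g_t in batch:
        err = v_phi(phi, s) - g_t                     # dL/dV = 2 * (V - G_t)
        for i, x in enumerate(FEATURES[s]):
            grad[i] += 2.0 * err * x / len(batch)     # chain rule: dV/dphi_i = x_i
    return [w - lr * g for w, g in zip(phi, grad)]


# (s_t, G_t) pairs collected from the example episode
batch = [("S", -7.0), ("S", -5.0), ("M", -4.0), ("S", -2.0), ("M", -1.0)]
phi = [0.0, 0.0]
for _ in range(500):
    phi = batch_mc_step(phi, batch, lr=0.1)
print([round(w, 4) for w in phi])
```

Repeated steps drive $\phi$ toward the average return observed from each state, $(-7 - 5 - 2)/3 \approx -4.667$ for $S$ and $(-4 - 1)/2 = -2.5$ for $M$, which is the least-squares answer for this batch.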

TD: One-Step Updates

Temporal Difference (TD) learning updates the Critic using the TD Error. The Critic loss is:

$$L_{\text{Critic}} = \left( r + \gamma V_\phi(s') - V_\phi(s) \right)^2 = \delta^2 \tag{6.4}$$

Each symbol in this equation:

| Symbol | Meaning |
|---|---|
| $L_{\text{Critic}}$ | Critic loss, measuring the magnitude of the TD Error |
| $r$ | Immediate reward received at the current step |
| $\gamma$ | Discount factor |
| $V_\phi(s')$ | Critic's current value prediction for the next state $s'$ |
| $V_\phi(s)$ | Critic's current value prediction for the current state $s$ |
| $\delta$ | TD Error, i.e., $r + \gamma V_\phi(s') - V_\phi(s)$ |

Minimizing $\delta^2$ makes the Critic's predictions progressively more accurate. The meaning of $\delta$: after one step, it is the difference between "the reward actually received plus the discounted next-state prediction" and "the current prediction." $\delta > 0$ means the step turned out better than expected; $\delta < 0$ means it turned out worse.

Numerical Example

Using the same trajectory as MC:

$$S \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-1} G$$

The initial value table is all zeros, and the learning rate is $\alpha = 0.5$. TD updates after every step, always reading from the most recent table.

Step 1: $S \xrightarrow{-2} S$. Currently $V(S) = 0$ and $V(s') = V(S) = 0$.

$$\delta = r + \gamma V(s') - V(s) = -2 + 1 \times 0 - 0 = -2$$

$$V(S) \leftarrow V(S) + \alpha \cdot \delta = 0 + 0.5 \times (-2) = -1$$

Step 2: $S \xrightarrow{-1} M$. Currently $V(S) = -1$ (updated in the previous step) and $V(M) = 0$.

$$\delta = -1 + 1 \times 0 - (-1) = -1 + 0 + 1 = 0$$

$$V(S) \leftarrow -1 + 0.5 \times 0 = -1$$

$\delta = 0$ means that the one-step target (a reward of $-1$ plus $V(M) = 0$) exactly matches the current estimate $V(S) = -1$: the prediction had no error.

Step 3: $M \xrightarrow{-2} S$. Currently $V(M) = 0$ and $V(S) = -1$.

$$\delta = -2 + 1 \times (-1) - 0 = -3$$

$$V(M) \leftarrow 0 + 0.5 \times (-3) = -1.5$$

Note that $V(S) = -1$ here is the value learned in step 1 (and left unchanged by step 2): TD immediately uses freshly learned information.

Step 4: $S \xrightarrow{-1} M$. Currently $V(S) = -1$ and $V(M) = -1.5$.

$$\delta = -1 + 1 \times (-1.5) - (-1) = -1.5$$

$$V(S) \leftarrow -1 + 0.5 \times (-1.5) = -1 + (-0.75) = -1.75$$

Step 5: $M \xrightarrow{-1} G$. Currently $V(M) = -1.5$ and $V(G) = 0$.

$$\delta = -1 + 1 \times 0 - (-1.5) = 0.5$$

$$V(M) \leftarrow -1.5 + 0.5 \times 0.5 = -1.5 + 0.25 = -1.25$$

$\delta = 0.5 > 0$ indicates that moving right from $M$ to the terminal state turned out better than the current estimate of $V(M)$, so $V(M)$ is adjusted upward.

Step-by-step Summary

| Step | Transition | Updated state | Old $V(s)$ | $r$ | $V(s')$ | TD target $r + \gamma V(s')$ | $\delta$ | New $V(s)$ |
|---|---|---|---|---|---|---|---|---|
| 1 | $S \xrightarrow{-2} S$ | $S$ | $0$ | $-2$ | $0$ | $-2 + 0 = -2$ | $-2$ | $-1$ |
| 2 | $S \xrightarrow{-1} M$ | $S$ | $-1$ | $-1$ | $0$ | $-1 + 0 = -1$ | $0$ | $-1$ |
| 3 | $M \xrightarrow{-2} S$ | $M$ | $0$ | $-2$ | $-1$ | $-2 + (-1) = -3$ | $-3$ | $-1.5$ |
| 4 | $S \xrightarrow{-1} M$ | $S$ | $-1$ | $-1$ | $-1.5$ | $-1 + (-1.5) = -2.5$ | $-1.5$ | $-1.75$ |
| 5 | $M \xrightarrow{-1} G$ | $M$ | $-1.5$ | $-1$ | $0$ | $-1 + 0 = -1$ | $0.5$ | $-1.25$ |
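A minimal sketch of online TD(0) that reproduces this table, assuming transitions are given as `(s, r, s_next)` tuples and that the terminal state `G` simply keeps the value 0 (`td0_episode` is a hypothetical helper name):

```python
# Minimal sketch: online TD(0) on a value table, one update per environment step.
def td0_episode(values, transitions, alpha=0.5, gamma=1.0):
    """Update values in place; return (state, TD target, delta, new value) per step."""
    log = []
    for s, r, s_next in transitions:
        target = r + gamma * values[s_next]   # reads the latest table; G stays 0
        delta = target - values[s]
        values[s] += alpha * delta
        log.append((s, target, delta, values[s]))
    return log


V = {"S": 0.0, "M": 0.0, "G": 0.0}
episode = [("S", -2, "S"), ("S", -1, "M"), ("M", -2, "S"), ("S", -1, "M"), ("M", -1, "G")]
for row in td0_episode(V, episode):
    print(row)
```

Because the target is computed before `values[s]` changes, step 1 (where $s' = s = S$) uses the old $V(S) = 0$, exactly as in the hand calculation.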

TD Loss Computation

Using step 3 as an example, $\delta = -3$:

$$L = \delta^2 = (-3)^2 = 9$$

The gradient-descent direction:

$$\frac{\partial L}{\partial V(M)} = -2\delta = -2 \times (-3) = 6$$

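The parameter moves in the direction of $-\frac{\partial L}{\partial V(M)}$, so $V(M)$ decreases. Here the TD target $r + \gamma V(S)$ is treated as a fixed number and not differentiated (a "semi-gradient"); with a neural-network Critic this corresponds to stopping the gradient through $V_\phi(s')$. Absorbing the factor of 2 into the learning rate, as in the MC case, gives the equivalent update $V(M) \leftarrow V(M) + \alpha \cdot \delta$, consistent with the table above.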
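To check the gradient numerically, the sketch below treats step 3's target $r + \gamma V(S)$ as a constant, compares the analytic derivative $-2\delta$ with a central finite difference, and contrasts the literal gradient step with the update that absorbs the factor of 2. All variable names are illustrative.

```python
# Minimal sketch: semi-gradient TD loss for step 3 (M --(-2)--> S), gamma = 1.
r, gamma = -2.0, 1.0
v_s_next = -1.0          # V(S), held fixed as part of the target
v_m = 0.0                # V(M), the value being trained


def td_loss(v):
    """L = delta^2 with the target r + gamma * V(S) treated as a constant."""
    delta = r + gamma * v_s_next - v
    return delta ** 2


delta = r + gamma * v_s_next - v_m                 # -3.0
analytic_grad = -2.0 * delta                        # dL/dV(M) = 6.0
eps = 1e-6
numeric_grad = (td_loss(v_m + eps) - td_loss(v_m - eps)) / (2 * eps)

alpha = 0.5
v_literal = v_m - alpha * analytic_grad             # literal gradient step
v_absorbed = v_m + alpha * delta                    # factor 2 absorbed into alpha
print(td_loss(v_m), analytic_grad, round(numeric_grad, 6), v_literal, v_absorbed)
```

Both updates move $V(M)$ downward ($-3$ and $-1.5$ respectively); they differ only by the factor of 2 in the effective step size.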

The advantages of TD methods (recall the TD(0) update $V(s) \leftarrow V(s) + \alpha[r + \gamma V(s') - V(s)]$):

  1. No need to wait for the episode to end: you can update at every step.
  2. Lower variance: the target involves only one random reward and transition, and $V_\phi(s')$ acts as an "anchor" that stabilizes the estimate.
  3. Matches the Actor's update cadence: both update once per environment step.

The price is bias: $V_\phi(s')$ is itself an estimate, not the true value. Using your own estimates to update your own estimates is called bootstrapping. In practice, the reduction in variance usually outweighs this bias, which is why TD is the common default for training the Critic.

Comparing the Three Methods

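| | DP | MC | TD |
|---|---|---|---|
| Used to train the Critic? | Theoretical baseline | Usable | Practical default |
| Needs the episode to end? | No | Yes | No |
| Unbiased? | Yes | Yes | No (biased, but lower variance) |
| Variance | None (exact expectations, no sampling) | High | Medium |
| Bootstrapping | Yes | No | Yes |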

MC vs. TD: A Numerical Comparison

Same trajectory $S \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-1} G$, initial table all zeros, $\alpha = 0.5$, $\gamma = 1$.

MC: updates only after the entire episode ends. At the first visit to $S$, the target is the complete return over the whole trajectory:

$$G_0 = (-2) + (-1) + (-2) + (-1) + (-1) = -7$$

$$V(S) \leftarrow 0 + 0.5 \times (-7 - 0) = -3.5$$

MC uses all information from start to finish in a single update.

TD: updates immediately after the first step, using only one-step information:

$$\delta = -2 + V(S) - V(S) = -2 + 0 - 0 = -2$$

$$V(S) \leftarrow 0 + 0.5 \times (-2) = -1$$

The TD target ($-2$) is much smaller in magnitude than the MC target ($-7$), because it looks only one step ahead and relies on the still-zero table for everything after that; in exchange, TD does not need to wait for the episode to end. As more trajectories accumulate, TD's $V(S)$ also approaches the true value of $-3.375$.

With suitably decreasing learning rates, both methods converge to the same $V^\pi$, but their update paths differ: MC makes large single updates ($-3.5$) with high variance; TD makes smaller updates ($-1$) more frequently, with lower variance.
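The convergence claim can be checked empirically with a small simulation. This is a sketch under stated assumptions: a hypothetical `sample_episode` simulator for the corridor under the fixed 0.8/0.2 policy, a small constant learning rate ($\alpha = 0.001$) instead of a decreasing schedule, and a fixed random seed. With a constant $\alpha$ the estimates keep fluctuating slightly around $V^\pi$ rather than settling exactly.

```python
import random


def sample_episode(rng):
    """Hypothetical corridor simulator: returns [(s, r, s_next), ...] from S to G."""
    s, steps = "S", []
    while s != "G":
        if rng.random() < 0.8:                                  # move right
            s_next, r = ("M", -1.0) if s == "S" else ("G", -1.0)
        else:                                                   # move left
            s_next, r = "S", -2.0      # bump the wall at S, or go back from M
        steps.append((s, r, s_next))
        s = s_next
    return steps


def compare_mc_td(num_episodes=100_000, alpha=0.001, gamma=1.0, seed=0):
    rng = random.Random(seed)
    v_mc = {"S": 0.0, "M": 0.0, "G": 0.0}
    v_td = {"S": 0.0, "M": 0.0, "G": 0.0}
    for _ in range(num_episodes):
        steps = sample_episode(rng)
        # TD(0): one update per step. The policy ignores V, so applying the
        # updates right after sampling matches updating during the episode.
        for s, r, s_next in steps:
            v_td[s] += alpha * (r + gamma * v_td[s_next] - v_td[s])
        # Every-visit MC: returns are known only once the episode is over;
        # visits are processed backward while accumulating G_t.
        g = 0.0
        for s, r, _ in reversed(steps):
            g = r + gamma * g
            v_mc[s] += alpha * (g - v_mc[s])
    return v_mc, v_td


v_mc, v_td = compare_mc_td()
print("MC:", {s: round(x, 2) for s, x in v_mc.items()})
print("TD:", {s: round(x, 2) for s, x in v_td.items()})
```

Both tables should land near $V(S) = -3.375$ and $V(M) = -1.875$; the exact digits depend on the seed and on $\alpha$.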

In practice, Actor-Critic methods usually train the Critic with TD. More advanced implementations (e.g., GAE in Chapter 7) combine MC and TD: a parameter $\lambda$ interpolates between them to tune the bias-variance tradeoff.

The Full Critic-Training Workflow

Putting the pieces together, a one-step Actor-Critic training loop looks like this:

  1. Interact: At state $s$, the Actor selects action $a$; the environment returns $r$ and $s'$.
  2. Forward pass: The Critic computes the current prediction $V_\phi(s)$ and the next-step prediction $V_\phi(s')$.
  3. Compute TD Error: $\delta = r + \gamma V_\phi(s') - V_\phi(s)$.
  4. Update Critic: Update the Critic parameters $\phi$ using $\delta^2$ as the loss.
  5. Update Actor: Update the Actor parameters $\theta$ using $\delta$ as the advantage estimate.
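Putting the five steps into a loop gives a one-step Actor-Critic sketch for the corridor. Assumptions beyond the text: a tabular Critic, a tabular softmax Actor with one logit per (state, action) pair (for which $\nabla \log \pi(a \mid s)$ with respect to the logits is the one-hot vector of $a$ minus $\pi(\cdot \mid s)$), the hypothetical `STEP` transition table, and the particular hyperparameters; none of these come from the original text.

```python
import math
import random

STATES = ("S", "M")
ACTIONS = ("left", "right")
# Hypothetical encoding of the corridor: (state, action) -> (next_state, reward)
STEP = {
    ("S", "left"): ("S", -2.0), ("S", "right"): ("M", -1.0),
    ("M", "left"): ("S", -2.0), ("M", "right"): ("G", -1.0),
}


def policy(theta, s):
    """Assumed tabular softmax Actor: pi(a|s) from per-state logits theta[s][a]."""
    m = max(theta[s].values())  # subtract the max logit for numerical stability
    exps = {a: math.exp(theta[s][a] - m) for a in ACTIONS}
    z = sum(exps.values())
    return {a: e / z for a, e in exps.items()}


def train(num_episodes=2000, gamma=0.9, alpha_phi=0.1, alpha_theta=0.1, seed=0):
    rng = random.Random(seed)
    V = {"S": 0.0, "M": 0.0, "G": 0.0}                       # tabular Critic
    theta = {s: {a: 0.0 for a in ACTIONS} for s in STATES}   # Actor logits (uniform start)
    for _ in range(num_episodes):
        s = "S"
        while s != "G":
            probs = policy(theta, s)
            # 1. Interact: sample a ~ pi(.|s), observe r and s'
            a = "right" if rng.random() < probs["right"] else "left"
            s_next, r = STEP[(s, a)]
            # 2-3. Forward pass and TD error (the terminal G keeps value 0)
            delta = r + gamma * V[s_next] - V[s]
            # 4. Critic: V(s) <- V(s) + alpha_phi * delta (factor 2 absorbed)
            V[s] += alpha_phi * delta
            # 5. Actor: theta <- theta + alpha_theta * delta * grad log pi(a|s)
            for b in ACTIONS:
                grad_log_pi = (1.0 if b == a else 0.0) - probs[b]
                theta[s][b] += alpha_theta * delta * grad_log_pi
            s = s_next
    return V, theta


V, theta = train()
print("Critic:", {s: round(x, 2) for s, x in V.items()})
print("pi(right|s):", {s: round(policy(theta, s)["right"], 3) for s in STATES})
```

Starting from a uniform policy, the Actor should end up moving right with high probability in both states, while the Critic's values approach those of the nearly deterministic "always right" policy.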

Numerical Walkthrough

Assume the current Critic value table is $V(S) = -1$, $V(M) = -0.5$, $V(G) = 0$, with $\gamma = 0.9$, Critic learning rate $\alpha_\phi = 0.1$, and Actor learning rate $\alpha_\theta = 0.01$.

Step 1: Interact

At state $S$, the Actor chooses right with probability 0.8 and left with probability 0.2. Suppose this sample picks right; the environment returns $r = -1$ and $s' = M$.

Step 2: Forward Pass

$$V_\phi(S) = -1, \quad V_\phi(M) = -0.5$$

Step 3: Compute TD Error

$$\delta = r + \gamma V_\phi(s') - V_\phi(s) = -1 + 0.9 \times (-0.5) - (-1) = -1 + (-0.45) + 1 = -0.45$$

$\delta = -0.45 < 0$ indicates that moving right from $S$ to $M$ turned out worse than predicted: the reward actually received ($-1$) plus $M$'s discounted estimate ($0.9 \times (-0.5) = -0.45$) totals $-1.45$, which is lower than $S$'s estimate of $-1$.

Step 4: Update Critic

$$L_{\text{Critic}} = \delta^2 = (-0.45)^2 = 0.2025$$

Parameter update (using a value table as an example):

$$V(S) \leftarrow V(S) + \alpha_\phi \cdot \delta = -1 + 0.1 \times (-0.45) = -1 + (-0.045) = -1.045$$

The Critic lowered $V(S)$: this experience suggests that $S$ is worth less than previously estimated.

Step 5: Update Actor

$\delta = -0.45$ indicates that this action (moving right) performed worse than expected, so the Actor's update should decrease its probability. Using the policy gradient as an example:

$$\theta \leftarrow \theta + \alpha_\theta \cdot \delta \cdot \nabla_\theta \log \pi(\text{right} \mid S)$$

Since $\delta < 0$, the parameters move opposite to $\nabla_\theta \log \pi(\text{right} \mid S)$, reducing the probability $\pi(\text{right} \mid S)$.

If $\delta > 0$, the action was better than expected, and the Actor increases its probability.
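The walkthrough can be replayed in a few lines. Because the text does not fix an Actor parameterization, this sketch assumes (hypothetically) a softmax over two logits at $S$, chosen as $\theta_S = (0, \ln 4)$ so that $\pi(\text{right} \mid S) = 0.8$; under that choice $\nabla_\theta \log \pi(\text{right} \mid S) = (-0.2, 0.2)$ for the (left, right) logits.

```python
import math

gamma, alpha_phi, alpha_theta = 0.9, 0.1, 0.01
V = {"S": -1.0, "M": -0.5, "G": 0.0}
# Hypothetical Actor at S: softmax over two logits with log-odds ln 4, so pi(right|S) = 0.8
theta_S = {"left": 0.0, "right": math.log(4.0)}


def pi_S(theta):
    """Softmax probabilities over the two actions at S."""
    z = sum(math.exp(x) for x in theta.values())
    return {a: math.exp(x) / z for a, x in theta.items()}


# Steps 1-2: the sampled transition and the Critic's predictions
s, a, r, s_next = "S", "right", -1.0, "M"
# Step 3: TD error
delta = r + gamma * V[s_next] - V[s]              # about -0.45
# Step 4: Critic update (value-table form)
V[s] += alpha_phi * delta                          # about -1.045
# Step 5: Actor update; for a softmax, grad log pi(a|s) = onehot(a) - pi(.|s)
probs = pi_S(theta_S)
for b in theta_S:
    theta_S[b] += alpha_theta * delta * ((1.0 if b == a else 0.0) - probs[b])
new_prob_right = pi_S(theta_S)["right"]
print(round(delta, 4), round(V[s], 4), round(probs["right"], 4), round(new_prob_right, 6))
```

The probability of moving right drops slightly below 0.8 (to roughly 0.79971), as expected for $\delta < 0$.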

The Critic parameters $\phi$ move in the direction that makes $\delta^2$ smaller, so predictions become more accurate. The Actor parameters $\theta$ move in the direction that assigns higher probability to actions with positive $\delta$, so choices improve. This creates a virtuous cycle: the more accurate the Critic's evaluation, the faster the Actor improves; and the more diverse the actions the Actor tries, the richer the data the Critic sees and the more accurate its evaluation becomes.
