跳转到正文

6.2 Critic 训练

上一节定义了优势函数 A(s,a)δ=r+γV(s)V(s)A(s,a) \approx \delta = r + \gamma V(s') - V(s),并引出了 Critic 网络作为 V(s)V(s) 的估计器。本节展开第 3 章速览过的 DP、MC、TD 三种方法在 Critic 训练中的具体实现。

本节会用到的前置知识

沿用第 3 章的三格走廊环境,固定策略 π\pi:在 SSMM 均以 0.8 概率右走、0.2 概率左走。环境转移和奖励如下:

当前状态动作策略概率下一状态奖励
$S$左走0.2$S$$-2$
$S$右走0.8$M$$-1$
$M$左走0.2$S$$-2$
$M$右走0.8$G$$-1$
$G$结束1.0$0$

γ=1\gamma=1。三种方法都估计同一张价值表,区别只在更新目标的来源。

DP:理论基准

如果完全知道环境的转移概率 PP 和奖励函数 RR(回顾:MDP 五元组),可以直接用贝尔曼期望方程迭代 Critic:

Vϕ(s)aπ(as)[R(s,a)+γsP(ss,a)Vϕ(s)]V_\phi(s) \leftarrow \sum_a \pi(a|s) \left[ R(s,a) + \gamma \sum_{s'} P(s'|s,a) V_\phi(s') \right]

这个式子的每个符号含义如下:

符号含义
$V_\phi(s)$Critic 对状态 ss 的当前价值估计,参数为 $\phi$
$a$在状态 ss 可以选择的动作(如左走、右走)
$\pi(a \mid s)$当前策略在状态 ss 选择动作 aa 的概率
$R(s,a)$在状态 ss 执行动作 aa 后获得的即时奖励
$s'$执行动作 aa 后可能到达的下一状态
$P(s' \mid s,a)$在状态 ss 做动作 aa 后转移到 ss' 的概率
$V_\phi(s')$Critic 对下一状态 ss' 的当前价值估计
$\gamma$折扣因子,决定下一状态价值打多少折扣

对走廊的 SS 展开。外层按策略对动作加权,内层按转移概率对下一状态加权。由于转移确定(右走必到右边、左走必撞墙或退回),内层 s\sum_{s'} 只有真正到达的下一状态概率为 1:

Vϕ(S)π(S)[R(S,)+Vϕ(M)]+π(S)[R(S,)+Vϕ(S)]=0.8[1+Vϕ(M)]+0.2[2+Vϕ(S)]\begin{aligned} V_\phi(S) &\leftarrow \pi(\text{右} \mid S)\left[R(S,\text{右}) + V_\phi(M)\right] + \pi(\text{左} \mid S)\left[R(S,\text{左}) + V_\phi(S)\right] \\ &= 0.8\left[-1 + V_\phi(M)\right] + 0.2\left[-2 + V_\phi(S)\right] \end{aligned}

MM 同理,右走到终点 GG、左走退回 SS

Vϕ(M)π(M)[R(M,)+Vϕ(G)]+π(M)[R(M,)+Vϕ(S)]=0.8[1+Vϕ(G)]+0.2[2+Vϕ(S)]\begin{aligned} V_\phi(M) &\leftarrow \pi(\text{右} \mid M)\left[R(M,\text{右}) + V_\phi(G)\right] + \pi(\text{左} \mid M)\left[R(M,\text{左}) + V_\phi(S)\right] \\ &= 0.8\left[-1 + V_\phi(G)\right] + 0.2\left[-2 + V_\phi(S)\right] \end{aligned}

反复对所有状态执行这个更新,VϕV_\phi 会收敛到 VπV^\pi 的精确值。下面从全 0 的初始表开始,逐轮代入数字。

第 1 轮——旧表全为 0,目标只剩下眼前动作代价的平均:

V1(S)=0.8[1+V0(M)]+0.2[2+V0(S)]=0.8(1+0)+0.2(2+0)=0.80.4=1.2\begin{aligned} V_1(S) &= 0.8[-1 + V_0(M)] + 0.2[-2 + V_0(S)] \\ &= 0.8(-1 + 0) + 0.2(-2 + 0) = -0.8 - 0.4 = -1.2 \end{aligned}

V1(M)=0.8[1+V0(G)]+0.2[2+V0(S)]=0.8(1+0)+0.2(2+0)=0.80.4=1.2\begin{aligned} V_1(M) &= 0.8[-1 + V_0(G)] + 0.2[-2 + V_0(S)] \\ &= 0.8(-1 + 0) + 0.2(-2 + 0) = -0.8 - 0.4 = -1.2 \end{aligned}

第 2 轮——把第 1 轮结果作为旧表:

V2(S)=0.8[1+V1(M)]+0.2[2+V1(S)]=0.8[1+(1.2)]+0.2[2+(1.2)]=0.8×(2.2)+0.2×(3.2)=1.760.64=2.4\begin{aligned} V_2(S) &= 0.8[-1 + V_1(M)] + 0.2[-2 + V_1(S)] \\ &= 0.8[-1 + (-1.2)] + 0.2[-2 + (-1.2)] \\ &= 0.8 \times (-2.2) + 0.2 \times (-3.2) = -1.76 - 0.64 = -2.4 \end{aligned}

V2(M)=0.8[1+V1(G)]+0.2[2+V1(S)]=0.8[1+0]+0.2[2+(1.2)]=0.8×(1)+0.2×(3.2)=0.80.64=1.44\begin{aligned} V_2(M) &= 0.8[-1 + V_1(G)] + 0.2[-2 + V_1(S)] \\ &= 0.8[-1 + 0] + 0.2[-2 + (-1.2)] \\ &= 0.8 \times (-1) + 0.2 \times (-3.2) = -0.8 - 0.64 = -1.44 \end{aligned}

第 3 轮——把第 2 轮结果作为旧表:

V3(S)=0.8[1+V2(M)]+0.2[2+V2(S)]=0.8[1+(1.44)]+0.2[2+(2.4)]=0.8×(2.44)+0.2×(4.4)=1.9520.88=2.832\begin{aligned} V_3(S) &= 0.8[-1 + V_2(M)] + 0.2[-2 + V_2(S)] \\ &= 0.8[-1 + (-1.44)] + 0.2[-2 + (-2.4)] \\ &= 0.8 \times (-2.44) + 0.2 \times (-4.4) = -1.952 - 0.88 = -2.832 \end{aligned}

V3(M)=0.8[1+V2(G)]+0.2[2+V2(S)]=0.8(1+0)+0.2[2+(2.4)]=0.8+0.2×(4.4)=0.80.88=1.68\begin{aligned} V_3(M) &= 0.8[-1 + V_2(G)] + 0.2[-2 + V_2(S)] \\ &= 0.8(-1 + 0) + 0.2[-2 + (-2.4)] \\ &= -0.8 + 0.2 \times (-4.4) = -0.8 - 0.88 = -1.68 \end{aligned}

汇总每轮结果:

轮次$V(S)$$V(M)$$V(G)$
0000
1-1.2-1.20
2-2.4-1.440
3-2.832-1.680
收敛-3.375-1.8750

每轮更新中,SSMM 的值都包含了"按当前策略行动的平均后果"——右走通常更好,但策略偶尔会左走,绕路和撞墙的代价也必须进入价值表。

在这个基础上,还可以进行策略改进——在状态 ss 选择让 Q(s,a)Q(s,a) 最大的动作(回顾:贪心最优策略)。"评估策略 → 改进策略 → 再评估"的循环就是策略迭代(Policy Iteration),理论上保证收敛到最优策略。

但在真实问题中,几乎不可能知道完整的 PPRR。DP 在 Actor-Critic 中的角色更多是理论基准——它告诉你"知道一切时 Critic 的最优答案"。

MC:用完整轨迹更新 Critic

跑完一个完整的 episode,用实际回报 GtG_t 来更新 Critic。Critic 的损失函数是均方误差:

LCritic=(GtVϕ(s))2(6.3)L_{\text{Critic}} = \left( G_t - V_\phi(s) \right)^2 \tag{6.3}

这个式子中每个符号的含义:

符号含义
$L_{\text{Critic}}$Critic 的损失函数,衡量预测偏差的大小
$G_t$从时刻 tt 开始到 episode 结束的实际折扣回报(MC 目标)
$V_\phi(s)$Critic 对状态 ss 的当前价值预测

GtVϕ(s)G_t - V_\phi(s) 是 Critic 的预测误差——实际拿了 GtG_t 分,但之前预测是 Vϕ(s)V_\phi(s) 分。损失是这个误差的平方。

具体数值例子

假设采样到一条轨迹:

S2S1M2S1M1GS \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-1} G

γ=1\gamma=1,从每次访问位置到终点,倒着累加得到 GtG_t

访问位置状态后续奖励GtG_t 的计算MC 目标 $G_t$
第 1 步$S$$-2,-1,-2,-1,-1$$-2 + (-1) + (-2) + (-1) + (-1)$$-7$
第 2 步$S$$-1,-2,-1,-1$$-1 + (-2) + (-1) + (-1)$$-5$
第 3 步$M$$-2,-1,-1$$-2 + (-1) + (-1)$$-4$
第 4 步$S$$-1,-1$$-1 + (-1)$$-2$
第 5 步$M$$-1$$-1$$-1$

损失计算与梯度更新

假设 Critic 是一张简单的价值表,当前 V(S)=0V(S) = 0V(M)=0V(M) = 0。以第 1 次访问 SS 为例,MC 目标 Gt=7G_t = -7

L=(GtV(S))2=(70)2=49L = (G_t - V(S))^2 = (-7 - 0)^2 = 49

梯度下降更新(学习率 α=0.5\alpha = 0.5):

V(S)V(S)αLV(S)=V(S)α2(V(S)Gt)V(S) \leftarrow V(S) - \alpha \cdot \frac{\partial L}{\partial V(S)} = V(S) - \alpha \cdot 2(V(S) - G_t)

这里 LV(S)=2(V(S)Gt)=2(0(7))=14\frac{\partial L}{\partial V(S)} = 2(V(S) - G_t) = 2(0 - (-7)) = 14,但更常见的是将 12\frac{1}{2} 吸收进学习率,直接写为

V(S)V(S)+α(GtV(S))=0+0.5×(70)=3.5V(S) \leftarrow V(S) + \alpha (G_t - V(S)) = 0 + 0.5 \times (-7 - 0) = -3.5

逐次访问的完整更新过程如下:

被更新的状态MC 目标 $G_t$旧值更新计算新值
第 1 次 $S$$-7$0$0 + 0.5 \times (-7 - 0) = -3.5$$-3.5$
第 2 次 $S$$-5$$-3.5$$-3.5 + 0.5 \times [-5 - (-3.5)] = -3.5 + 0.5 \times (-1.5) = -4.25$$-4.25$
第 1 次 $M$$-4$0$0 + 0.5 \times (-4 - 0) = -2$$-2$
第 3 次 $S$$-2$$-4.25$$-4.25 + 0.5 \times [-2 - (-4.25)] = -4.25 + 0.5 \times 2.25 = -3.125$$-3.125$
第 2 次 $M$$-1$$-2$$-2 + 0.5 \times [-1 - (-2)] = -2 + 0.5 \times 1 = -1.5$$-1.5$

MC 方法(回顾:MC 价值更新 V(s)V(s)+α[GtV(s)]V(s) \leftarrow V(s) + \alpha[G_t - V(s)])给出无偏估计(用的是真实回报),但有两个限制:

  1. 必须等 episode 结束才能计算 GtG_t,不能边走边学
  2. 方差大——不同 episode 的 GtG_t 波动剧烈

在神经网络实现中,MC 方法等价于:跑完一个 episode,收集所有 (st,Gt)(s_t, G_t) 对,然后用这些数据做一次梯度下降更新 Critic 的参数 ϕ\phi

TD:单步更新

TD Error 来更新 Critic。Critic 的损失函数是:

LCritic=(r+γVϕ(s)Vϕ(s))2=δ2(6.4)L_{\text{Critic}} = \left( r + \gamma V_\phi(s') - V_\phi(s) \right)^2 = \delta^2 \tag{6.4}

这个式子中每个符号的含义:

符号含义
$L_{\text{Critic}}$Critic 的损失函数,衡量 TD Error 的大小
$r$当前步获得的即时奖励
$\gamma$折扣因子
$V_\phi(s')$Critic 对下一状态 ss' 的当前价值预测
$V_\phi(s)$Critic 对当前状态 ss 的当前价值预测
$\delta$TD Error,即 $r + \gamma V_\phi(s') - V_\phi(s)$

最小化 δ2\delta^2 就是让 Critic 的预测越来越准确。δ\delta 的含义是:走了一步之后,"实际拿到的奖励 + 下一步预测"与"当前预测"之间的差。δ>0\delta > 0 表示这一步比预期好,δ<0\delta < 0 表示比预期差。

具体数值例子

使用与 MC 相同的轨迹:

S2S1M2S1M1GS \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-1} G

初始价值表全为 0,学习率 α=0.5\alpha = 0.5。TD 每走一步就更新一次,读取的是当前最新的表

第 1 步S2SS \xrightarrow{-2} S。当前 V(S)=0V(S) = 0V(S)=V(S)=0V(S') = V(S) = 0

δ=r+γV(s)V(s)=2+1×00=2\delta = r + \gamma V(s') - V(s) = -2 + 1 \times 0 - 0 = -2

V(S)V(S)+αδ=0+0.5×(2)=1V(S) \leftarrow V(S) + \alpha \cdot \delta = 0 + 0.5 \times (-2) = -1

第 2 步S1MS \xrightarrow{-1} M。当前 V(S)=1V(S) = -1(已被上一步更新),V(M)=0V(M) = 0

δ=1+1×0(1)=1+0+1=0\delta = -1 + 1 \times 0 - (-1) = -1 + 0 + 1 = 0

V(S)1+0.5×0=1V(S) \leftarrow -1 + 0.5 \times 0 = -1

δ=0\delta = 0 说明"拿了 1-1 然后到了 V(M)=0V(M) = 0 的状态"恰好等于之前对 V(S)V(S) 的估计 1-1,预测没有偏差。

第 3 步M2SM \xrightarrow{-2} S。当前 V(M)=0V(M) = 0V(S)=1V(S) = -1

δ=2+1×(1)0=3\delta = -2 + 1 \times (-1) - 0 = -3

V(M)0+0.5×(3)=1.5V(M) \leftarrow 0 + 0.5 \times (-3) = -1.5

注意这里 V(S)=1V(S) = -1 是第 1 步刚更新过的值——TD 立刻把刚学到的信息拿来用了。

第 4 步S1MS \xrightarrow{-1} M。当前 V(S)=1V(S) = -1V(M)=1.5V(M) = -1.5

δ=1+1×(1.5)(1)=1.5\delta = -1 + 1 \times (-1.5) - (-1) = -1.5

V(S)1+0.5×(1.5)=1+(0.75)=1.75V(S) \leftarrow -1 + 0.5 \times (-1.5) = -1 + (-0.75) = -1.75

第 5 步M1GM \xrightarrow{-1} G。当前 V(M)=1.5V(M) = -1.5V(G)=0V(G) = 0

δ=1+1×0(1.5)=0.5\delta = -1 + 1 \times 0 - (-1.5) = 0.5

V(M)1.5+0.5×0.5=1.5+0.25=1.25V(M) \leftarrow -1.5 + 0.5 \times 0.5 = -1.5 + 0.25 = -1.25

δ=0.5>0\delta = 0.5 > 0,说明从 MM 右走到终点的体验比 V(M)V(M) 当前估计要好,V(M)V(M) 因此上调。

逐步汇总表

步骤实际发生的一步被更新的状态旧 $V(s)$$r$$V(s')$TD 目标 $r + \gamma V(s')$$\delta$新 $V(s)$
1$S \xrightarrow{-2} S$$S$0$-2$0$-2 + 0 = -2$$-2$$-1$
2$S \xrightarrow{-1} M$$S$$-1$$-1$0$-1 + 0 = -1$$0$$-1$
3$M \xrightarrow{-2} S$$M$0$-2$$-1$$-2 + (-1) = -3$$-3$$-1.5$
4$S \xrightarrow{-1} M$$S$$-1$$-1$$-1.5$$-1 + (-1.5) = -2.5$$-1.5$$-1.75$
5$M \xrightarrow{-1} G$$M$$-1.5$$-1$0$-1 + 0 = -1$$0.5$$-1.25$

TD 损失计算

以第 3 步为例,δ=3\delta = -3

L=δ2=(3)2=9L = \delta^2 = (-3)^2 = 9

梯度下降更新方向:

LV(M)=2δ=2×(3)=6\frac{\partial L}{\partial V(M)} = -2\delta = -2 \times (-3) = 6

参数沿 LV(M)-\frac{\partial L}{\partial V(M)} 方向移动,即 V(M)V(M) 下降。实际更新中等效为 V(M)V(M)+αδV(M) \leftarrow V(M) + \alpha \cdot \delta,与上表一致。

TD 方法(回顾:TD(0) 更新 V(s)V(s)+α[r+γV(s)V(s)]V(s) \leftarrow V(s) + \alpha[r + \gamma V(s') - V(s)])的优势:

  1. 不需要等 episode 结束——每走一步就能更新
  2. 方差低——Vϕ(s)V_\phi(s') 作为"锚点"稳定了估计
  3. 与 Actor 的更新节奏一致——两者都是走一步更新一次

代价是引入了偏差Vϕ(s)V_\phi(s') 本身也是一个估计值,不是真实的价值。这叫做自举(Bootstrapping)——用自己的估计来更新自己的估计。但实际中,这个偏差远小于方差降低带来的好处。

三种方法的对比

DPMCTD
用于 Critic 训练?理论基准可以用实际首选
需要 episode 结束?不需要需要不需要
无偏?否(有偏但方差低)
方差
自举

MC 与 TD 的数值对比

同一条轨迹 S2S1M2S1M1GS \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-2} S \xrightarrow{-1} M \xrightarrow{-1} G,初始表全 0,α=0.5\alpha = 0.5γ=1\gamma = 1

MC——等整局结束后才更新。第 1 次访问 SS 时目标为整条轨迹的完整回报:

G0=(2)+(1)+(2)+(1)+(1)=7G_0 = (-2) + (-1) + (-2) + (-1) + (-1) = -7

V(S)0+0.5×(70)=3.5V(S) \leftarrow 0 + 0.5 \times (-7 - 0) = -3.5

MC 一次性用从起点到终点的全部信息来更新。

TD——走第 1 步后立刻更新。第 1 步只用到一步信息:

δ=2+V(S)V(S)=2+00=2\delta = -2 + V(S) - V(S) = -2 + 0 - 0 = -2

V(S)0+0.5×(2)=1V(S) \leftarrow 0 + 0.5 \times (-2) = -1

TD 目标 (2)(-2) 远小于 MC 目标 (7)(-7),但 TD 不需要等整局结束。随着更多轨迹的积累,TD 的 V(S)V(S) 也会逐步逼近真实值 3.375-3.375

两种方法最终收敛到同一个 VπV^\pi,但更新路径不同:MC 单次更新幅度大(3.5-3.5),方差高;TD 单次更新幅度小(1-1),但更频繁,方差低。

实际中,Actor-Critic 几乎都用 TD 方法来训练 Critic。在更高级的实现中(如第 7 章的 GAE),MC 和 TD 会被组合使用——通过参数 λ\lambda 在两者之间插值,获得偏差和方差的最佳平衡。

Critic 训练的完整流程

将以上内容整合,Actor-Critic 的单步训练流程如下:

  1. 交互:在状态 ss 下,Actor 选择动作 aa,环境返回 rr 和 $s'$
  2. 前向传播:Critic 计算当前预测 Vϕ(s)V_\phi(s) 和下一步预测 $V_\phi(s')$
  3. 计算 TD Error:$\delta = r + \gamma V_\phi(s') - V_\phi(s)$
  4. 更新 Critic:用 δ2\delta^2 作为损失更新 Critic 的参数 $\phi$
  5. 更新 Actor:用 δ\delta 作为优势估计更新 Actor 的参数 $\theta$

具体数值 walkthrough

假设当前 Critic 的价值表为 V(S)=1V(S) = -1V(M)=0.5V(M) = -0.5V(G)=0V(G) = 0γ=0.9\gamma = 0.9,Critic 学习率 αϕ=0.1\alpha_\phi = 0.1,Actor 学习率 αθ=0.01\alpha_\theta = 0.01

第 1 步:交互

在状态 SS,Actor 以概率 0.8 选择右走、0.2 选择左走。假设这次采样到右走,环境返回 r=1r = -1s=Ms' = M

第 2 步:前向传播

Vϕ(S)=1,Vϕ(M)=0.5V_\phi(S) = -1, \quad V_\phi(M) = -0.5

第 3 步:计算 TD Error

δ=r+γVϕ(s)Vϕ(s)=1+0.9×(0.5)(1)=1+(0.45)+1=0.45\delta = r + \gamma V_\phi(s') - V_\phi(s) = -1 + 0.9 \times (-0.5) - (-1) = -1 + (-0.45) + 1 = -0.45

δ=0.45<0\delta = -0.45 < 0,说明从 SS 右走到 MM 的体验比当前预测要差——实际拿到 1-1 加上 MM 的估计 0.45-0.45,总共 1.45-1.45,低于对 SS 的估计 1-1

第 4 步:更新 Critic

LCritic=δ2=(0.45)2=0.2025L_{\text{Critic}} = \delta^2 = (-0.45)^2 = 0.2025

参数更新(以价值表为例):

V(S)V(S)+αϕδ=1+0.1×(0.45)=1+(0.045)=1.045V(S) \leftarrow V(S) + \alpha_\phi \cdot \delta = -1 + 0.1 \times (-0.45) = -1 + (-0.045) = -1.045

Critic 降低了 V(S)V(S)——因为这次体验表明 SS 的价值比之前估计的还要低。

第 5 步:更新 Actor

δ=0.45\delta = -0.45 表示这次动作(右走)的表现不如预期。Actor 的更新方向是:降低这个动作的概率。以策略梯度为例:

θθ+αθδθlogπ(S)\theta \leftarrow \theta + \alpha_\theta \cdot \delta \cdot \nabla_\theta \log \pi(\text{右} \mid S)

δ<0\delta < 0 使得参数沿 θlogπ(S)\nabla_\theta \log \pi(\text{右} \mid S) 的反方向移动,即降低 π(S)\pi(\text{右} \mid S) 的概率。

如果 δ>0\delta > 0,则表示这个动作比预期好,Actor 会增加该动作的概率。

Critic 的参数 ϕ\phi 沿着"让 δ2\delta^2 更小"的方向更新——预测越来越准。Actor 的参数 θ\theta 沿着"让正 δ\delta 的动作概率更高"的方向更新——选择越来越好。两者形成良性循环:Critic 的评分越准,Actor 的进步就越快;Actor 尝试的新动作越多,Critic 看到的数据就越丰富,评分也越准。

参考文献

现代强化学习实战课程