跳转到正文

6.1 优势函数

第 5 章末尾我们发现:减掉基线 V(s)V(s) 可以降低策略梯度的方差,而不改变梯度的方向。本节将深入这个关键洞察,引出优势函数——它是连接 Actor 和 Critic 的桥梁。

本节会用到的前置知识

  • REINFORCE 策略梯度 θJθlogπ(as)Gt\nabla_\theta J \approx \nabla_\theta \log \pi(a|s) \cdot G_t——基线要加在哪里
  • 状态价值 V(s)V(s)——最好的基线是什么
  • 动作价值 Q(s,a)Q(s,a)——优势函数的定义依赖 QQVV 的差
  • TD Error δ=r+γV(s)V(s)\delta = r + \gamma V(s') - V(s)——优势函数的实用估计方法

从基线到优势函数

回忆第 5 章 REINFORCE 的策略梯度

θJθlogπ(as)Gt\nabla_\theta J \approx \nabla_\theta \log \pi(a|s) \cdot G_t

GtG_t 是从当前步到 episode 结束的总回报(回顾:折扣累积回报)。问题在于 GtG_t 波动巨大——同一个策略、同一个状态,跑两次可能拿到完全不同的 GtG_t

减掉基线 V(s)V(s) 后:

θJθlogπ(as)(GtV(s))\nabla_\theta J \approx \nabla_\theta \log \pi(a|s) \cdot (G_t - V(s))

括号里的 GtV(s)G_t - V(s) 就是优势函数(Advantage Function) 的一种估计。优势函数的正式定义是:

Aπ(s,a)=Qπ(s,a)Vπ(s)(6.1)A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s) \tag{6.1}

符号含义
$A^\pi(s,a)$优势函数:在状态 ss 做动作 aa,比"平均水平"好了多少。
$Q^\pi(s,a)$动作价值函数:在状态 ss 先做动作 aa,之后按策略 π\pi 行动的期望折扣回报。
$V^\pi(s)$状态价值函数:在状态 ss 按策略 π\pi 行动的期望折扣回报。
$\pi$当前策略,决定在每个状态下各动作的概率。

两者的差恰好是"因为做了动作 aa,多拿了多少分"。

优势函数的含义是:做了这个动作,比"平均能拿多少分"好了多少。

  • A>0A > 0:这个动作比预期好,应该多选
  • A<0A < 0:这个动作比预期差,应该少选
  • A0A \approx 0:这个动作和预期差不多

用下棋类比:V(s)V(s) 是"这个棋局整体胜率 60%",Q(s,出车)Q(s, \text{出车}) 是"走车之后胜率 75%"。优势 A=75%60%=15%A = 75\% - 60\% = 15\%,说明走车比平均水平好了 15%,是个好选择。

用一个具体的 3 步 episode 来看优势函数是怎么算出来的。假设折扣因子 γ=0.9\gamma = 0.9,某次采样得到如下轨迹:

s0r=+2s1r=+3s2r=+1s3 (结束)s_0 \xrightarrow{r=+2} s_1 \xrightarrow{r=+3} s_2 \xrightarrow{r=+1} s_3\ (\text{结束})

从每个时刻开始计算折扣累积回报 GtG_t

G0=r1+γr2+γ2r3=2+0.9×3+0.92×1=2+2.7+0.81=5.51G_0 = r_1 + \gamma r_2 + \gamma^2 r_3 = 2 + 0.9 \times 3 + 0.9^2 \times 1 = 2 + 2.7 + 0.81 = 5.51

G1=r2+γr3=3+0.9×1=3.9G_1 = r_2 + \gamma r_3 = 3 + 0.9 \times 1 = 3.9

G2=r3=1G_2 = r_3 = 1

再假设 Critic 已经给出了各状态的价值估计:

状态$V(s)$
$s_0$3.0
$s_1$2.5
$s_2$0.8

GtG_tV(s)V(s) 代入 AGtV(s)A \approx G_t - V(s),得到每个时刻的优势估计:

时刻 $t$状态$G_t$$V(s_t)$$A = G_t - V(s_t)$含义
0$s_0$$5.51$$3.0$$5.51 - 3.0 = 2.51$比预期好了 $2.51$
1$s_1$$3.9$$2.5$$3.9 - 2.5 = 1.4$比预期好了 $1.4$
2$s_2$$1$$0.8$$1 - 0.8 = 0.2$比预期好了 $0.2$

三个时刻的优势都是正数,说明这次轨迹上的每个动作都比平均水平表现更好。GtV(s)G_t - V(s) 是用 MC 回报对优势函数的估计;它无偏但方差高(不同的轨迹会给出差异很大的 GtG_t)。

优势函数与累积回报

优势函数之所以能降低方差,核心在于它减掉了"本来就能拿到的分",只保留"因为做了这个动作多拿的分"。

构造一个更完整的例子。假设某状态 ss 下,策略的平均回报是 V(s)=10V(s) = 10。采样 4 条轨迹,回报分别是 Gt(1)=18G_t^{(1)} = 18Gt(2)=15G_t^{(2)} = 15Gt(3)=7G_t^{(3)} = 7Gt(4)=4G_t^{(4)} = 4

先看用 GtG_t 作梯度信号的情况:

Episode$G_t$梯度信号含义
118$\nabla \times 18$大正数,强推该动作
215$\nabla \times 15$正数,推该动作
37$\nabla \times 7$正数,推该动作
44$\nabla \times 4$正数,推该动作

四次都是正数。策略会认为"在这个状态下,不管怎样这个动作都是好的"——但 episode 3 和 4 的回报实际上低于平均水平。

再看用 A=GtV(s)A = G_t - V(s) 的情况:

Episode$G_t$$V(s)$$A = G_t - V(s)$梯度信号含义
11810$18 - 10 = +8$$\nabla \times (+8)$比平均好很多,强推
21510$15 - 10 = +5$$\nabla \times (+5)$比平均好,推
3710$7 - 10 = -3$$\nabla \times (-3)$比平均差,抑制
4410$4 - 10 = -6$$\nabla \times (-6)$比平均差很多,强抑制

GtG_t 时,四个 episode 都给出正的梯度信号,策略无法区分"真的好"和"碰巧碰上了高回报"。用 AA 时,信号被校准了:高于平均的给正信号,低于平均的给负信号。

量化看方差变化。用 GtG_t 时,四个信号的均值为 18+15+7+44=11\frac{18+15+7+4}{4} = 11,方差为 (1811)2+(1511)2+(711)2+(411)24=49+16+16+494=32.5\frac{(18-11)^2+(15-11)^2+(7-11)^2+(4-11)^2}{4} = \frac{49+16+16+49}{4} = 32.5。用 AA 时,四个信号的均值为 8+5364=1\frac{8+5-3-6}{4} = 1,方差为 (81)2+(51)2+(31)2+(61)24=49+16+16+494=32.5\frac{(8-1)^2+(5-1)^2+(-3-1)^2+(-6-1)^2}{4} = \frac{49+16+16+49}{4} = 32.5

四个样本的方差相同,但 AA 的均值更接近零。当样本量增大时,GtG_t 的波动范围由整条轨迹的随机性决定(可能从 0 到几十),而 AA 的波动范围被 V(s)V(s) 中心化,正负抵消使得梯度的期望方向更稳定。这正是"减掉基线降低方差"的机制。

用 TD Error 估计优势

优势函数的理论定义是 A=QVA = Q - V,但实际中通常不直接计算 QQ。从定义出发,经过一步展开就能得到一个更实用的形式。

Aπ(s,a)=Qπ(s,a)Vπ(s)A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s) 开始。动作价值函数的定义是:

Qπ(s,a)=E[Rt+1+γVπ(St+1)St=s,At=a]Q^\pi(s,a) = \mathbb{E}\left[R_{t+1} + \gamma V^\pi(S_{t+1}) \mid S_t = s, A_t = a\right]

这个期望的含义:在状态 ss 做了动作 aa 之后,拿到的即时奖励加上下一状态的价值。如果只取一次采样(不走完整个 episode,也不对所有可能转移求平均),就得到 QQ 的一步估计:

Q(s,a)r+γV(s)Q(s,a) \approx r + \gamma V(s')

其中 rr 是这一步实际拿到的奖励,ss' 是这一步实际到达的下一状态。把这个近似代入优势函数定义:

A(s,a)=Q(s,a)V(s)r+γV(s)V(s)A(s,a) = Q(s,a) - V(s) \approx r + \gamma V(s') - V(s)

右边就是TD Error

A(s,a)r+γV(s)V(s)=δ(6.2)A(s,a) \approx r + \gamma V(s') - V(s) = \delta \tag{6.2}

符号含义
$r$这一步实际拿到的即时奖励。
$\gamma$折扣因子,控制未来价值被打多少折扣。
$V(s')$Critic 对下一状态 ss' 的价值估计。
$V(s)$Critic 对当前状态 ss 的价值估计。
$\delta$TD Error:走了一步之后,实际结果比预测好了多少。

用 TD Error 替代 GtG_t 作为策略梯度的信号,有两个好处:

  1. 不需要等 episode 结束——每走一步就能更新(GtG_t 需要跑完一整局,这是MC 方法的限制)
  2. 方差更低——δ\delta 只涉及一步的随机性(GtG_t 涉及整条轨迹的随机性)

用一个具体的数值例子走一遍。假设 γ=0.9\gamma = 0.9,在某一步中:

  • 当前状态 ss,Critic 估计 $V(s) = 5.0$
  • 智能体做了某个动作,拿到即时奖励 $r = +2$
  • 到达下一状态 ss',Critic 估计 $V(s') = 4.0$

代入 TD Error 公式:

δ=r+γV(s)V(s)=2+0.9×4.05.0=2+3.65.0=+0.6\delta = r + \gamma V(s') - V(s) = 2 + 0.9 \times 4.0 - 5.0 = 2 + 3.6 - 5.0 = +0.6

δ=+0.6\delta = +0.6,说明这一步比 Critic 的预测好了 0.60.6。用这个 δ\delta 作为优势估计,策略梯度会轻微推动该动作的概率上升。

换一组数字。假设同一次转移中 r=1r = -1

δ=1+0.9×4.05.0=1+3.65.0=2.4\delta = -1 + 0.9 \times 4.0 - 5.0 = -1 + 3.6 - 5.0 = -2.4

δ=2.4\delta = -2.4,说明这一步比预测差了很多。策略梯度会推动该动作的概率下降。

再看 δ=0\delta = 0 的情况。如果 r=+1r = +1V(s)=5.0V(s') = 5.0V(s)=5.5V(s) = 5.5

δ=1+0.9×5.05.5=1+4.55.5=0\delta = 1 + 0.9 \times 5.0 - 5.5 = 1 + 4.5 - 5.5 = 0

δ=0\delta = 0:这一步的实际结果和 Critic 的预测完全一致,策略梯度信号为零,该动作的概率不变。

现在把三个时刻连起来看。假设一个 3 步 episode,γ=0.9\gamma = 0.9

时刻状态动作$r$下一状态$V(s)$$V(s')$$\delta = r + \gamma V(s') - V(s)$
0$s_0$$a_0$$+3$$s_1$2.04.0$3 + 0.9 \times 4.0 - 2.0 = 3 + 3.6 - 2.0 = +4.6$
1$s_1$$a_1$$+1$$s_2$4.01.0$1 + 0.9 \times 1.0 - 4.0 = 1 + 0.9 - 4.0 = -2.1$
2$s_2$$a_2$$+2$$s_3$1.00.0$2 + 0.9 \times 0.0 - 1.0 = 2 + 0.0 - 1.0 = +1.0$

三步的 δ\delta 分别是 +4.6+4.62.1-2.1+1.0+1.0。时刻 0 的动作远超预期,策略应增加 a0a_0 的概率;时刻 1 的动作低于预期,策略应降低 a1a_1 的概率;时刻 2 略好于预期,轻微鼓励 a2a_2

对比用 MC 回报 GtG_t 的情况(同一条轨迹):

G0=3+0.9×1+0.92×2=3+0.9+1.62=5.52G_0 = 3 + 0.9 \times 1 + 0.9^2 \times 2 = 3 + 0.9 + 1.62 = 5.52

G1=1+0.9×2=2.8G_1 = 1 + 0.9 \times 2 = 2.8

G2=2G_2 = 2

对应的 MC 优势估计:

时刻$G_t$$V(s)$$A_{\text{MC}} = G_t - V(s)$
05.522.0$5.52 - 2.0 = +3.52$
12.84.0$2.8 - 4.0 = -1.2$
221.0$2 - 1.0 = +1.0$

两种估计给出的方向一致(正、负、正),但数值不同。TD 优势 δ\delta 只看一步,MC 优势 GtV(s)G_t - V(s) 看到终点。δ\delta 的方差更低(只含一步随机性),但有偏差(依赖 V(s)V(s') 的准确性);GtV(s)G_t - V(s) 无偏但方差高(包含整条轨迹的随机性)。

这是 MC → TD 的演进在策略空间的再现:REINFORCE 用 GtG_t(MC),Actor-Critic 用 δ\delta(TD)。

REINFORCE (MC)Actor-Critic (TD)
优势估计GtV(s)G_t - V(s)(需要完整轨迹)r+γV(s)V(s)=δr + \gamma V(s') - V(s) = \delta(走一步就更新)
更新时机episode 结束后每走一步
方差
代价需要训练 Critic

Critic 网络实现

要计算 δ=r+γV(s)V(s)\delta = r + \gamma V(s') - V(s),你需要知道 V(s)V(s)V(s)V(s')。但在真实问题中,VV 是未知的——需要一个网络来估计它。这个网络就是 Critic

Actor(策略网络)           Critic(价值网络)
  输入:状态 s                 输入:状态 s
  输出:π_θ(a|s) 概率分布     输出:V_φ(s) 标量
  作用:选动作                作用:评估状态价值
  参数:θ                     参数:φ

Actor 和 Critic 共享输入(状态 ss),但输出不同:Actor 输出动作概率分布,Critic 输出价值标量。它们通过优势函数 AδA \approx \delta 协作:Critic 给出评估,Actor 根据评估调整行为。

但 Critic 怎么训练?它怎么学会准确估计 V(s)V(s)?下一节将展开第 3 章速览过的 DP、MC、TD 三种方法在 Critic 训练中的具体应用。Critic 训练方法

现代强化学习实战课程