Contributions

本文首次将知识蒸馏/Knowledge Distillation 系统性地引入深度强化学习/Deep Reinforcement Learning/DRL 领域,提出策略蒸馏/Policy Distillation 方法,将训练好的 DQN 教师网络的策略迁移到更小的学生网络中。核心发现是:蒸馏成败的关键不在网络结构,而在损失函数如何刻画动作间差距的重要性——基于 softmax 温度参数的 KL 散度损失显著优于 MSE 回归和 NLL 分类损失。方法在四个场景下得到验证:单任务蒸馏、模型压缩(参数量缩小至 4% 仍保留 84% 性能)、多任务蒸馏(将多个单任务专家合并为一个多任务策略,且超越各单任务教师)、以及在线蒸馏(在 DQN 训练过程中持续跟踪教师,获得更稳定的策略)。

主要局限:实验仅在 Atari 游戏的离散动作空间上进行验证,未涉及连续动作空间;教师网络固定为 DQN,未探索其他 RL 算法作为教师的效果;多任务蒸馏依赖为每个任务设置独立的输出层(controller),并未实现完全统一的动作空间建模。

1. Introduction

深度强化学习通过 DQN 等方法已能在 Atari 游戏等复杂视觉任务上达到甚至超越人类水平,但 DQN 训练出的网络通常较大且训练代价高昂。这带来两个实际问题:其一,单个训练好的网络只能执行一个任务,无法共享跨任务的知识;其二,网络的最终容量远大于表达最终策略所需的容量——DQN 训练过程中网络需要先后表示一系列不同的策略(从随机探索到最终收敛),因此需要较大的网络来适应这一动态过程,但最终策略本身可能并不需要这么多参数。

知识蒸馏最初由 Model Compression 提出,后经 Distilling the Knowledge in a Neural Network 发展为通过升高 softmax 温度来”软化”教师网络的输出分布,使学生网络能学到更丰富的类间关系信息。然而,蒸馏此前仅应用于分类任务,其输出是归一化概率分布。将蒸馏应用到 RL 策略上面临一个关键困难:DQN 的输出是 Q 值——无界的实数值,其尺度取决于未来折扣回报,在许多状态下不同动作的 Q 值差异极小(例如游戏中的”无关帧”),而在关键决策点差异又很大。这使得直接回归 Q 值或简单分类都不理想,需要一种能根据动作间相对重要性来加权的损失函数。

本文的核心洞察是:策略蒸馏的关键在于损失函数的选择。通过将教师的 Q 值经 softmax(带温度参数 )转换为概率分布,再用 KL 散度来衡量学生与教师分布的差异,可以在”保留全部动作信息”和”聚焦最优动作”之间取得良好平衡。

2. Problem Setup

DQN 背景

给定环境 ,在时间步 ,智能体观察到 ,选择动作 ,获得奖励 。定义序列 及折扣回报:

最优动作价值函数 给出在状态 下执行动作 后遵循最优策略 所能获得的最大期望回报。DQN 用卷积神经网络逼近 ,通过最小化以下损失进行训练:

其中 为经验回放缓冲区, 为目标网络的参数(定期从 复制,用于稳定训练)。

策略蒸馏的设定

策略蒸馏的目标是将教师网络 (训练好的 DQN)的策略迁移到学生网络 。具体地,教师生成数据集:

其中每个样本包含一段短观察序列 及教师在该状态下输出的所有动作的 Q 值向量 。学生网络通过监督学习在此数据集上训练,不与环境交互。

3. Methods

方法总览

策略蒸馏的核心问题是:给定教师的 Q 值输出,如何设计损失函数来训练学生网络?论文系统比较了三种损失函数,并在此基础上将蒸馏扩展到单任务压缩、多任务合并和在线跟踪三个场景。

三种蒸馏损失函数

(1)负对数似然损失/Negative Log-Likelihood/NLL

仅使用教师 Q 值中的最优动作 作为硬标签,训练学生进行分类:

NLL 仅保留了”哪个动作最好”的信息,丢弃了动作间 Q 值差异的全部结构。其优势是简单直接,但缺陷在于:当教师本身存在噪声(Q 值估计不精确)时,NLL 会将教师的估计误差放大——它对所有被标为”最优”的动作赋予相同的置信度,无法区分”以微弱优势胜出的最优动作”和”以压倒性优势胜出的最优动作”。

(2)均方误差损失/Mean Squared Error/MSE

让学生直接回归教师的完整 Q 值向量:

其中 分别为教师和学生的 Q 值向量。MSE 保留了所有动作的完整 Q 值信息,但问题是:Q 值是无界的实数值,其绝对尺度可能很不稳定;更重要的是,MSE 对所有 Q 值差异一视同仁,而贪心策略的决策仅取决于动作间的相对排序。在 Atari 游戏中,大量状态下各动作的 Q 值极为接近(差异在小数点后几位),MSE 会将大量优化预算花费在这些对策略毫无影响的微小差异上。

(3)KL 散度损失/Kullback-Leibler Divergence/KL

先将教师和学生的 Q 值分别通过带温度参数 的 softmax 转换为概率分布,再最小化两者之间的 KL 散度:

KL 损失的直觉

KL 损失介于 NLL 和 MSE 之间,通过温度参数 控制信息保留的程度。在传统分类蒸馏中,教师的 softmax 输出已经很尖锐,需要升高温度来软化分布、暴露更多类间关系。但在策略蒸馏中情况相反:Q 值本身不是概率分布,不同动作的 Q 值差异往往很小,需要降低温度来锐化分布、突出动作间的相对重要性。实验中最优温度为 ,远小于 1,这印证了策略蒸馏需要锐化而非软化的直觉。

单任务策略蒸馏

单任务蒸馏的流程如下:教师 DQN 在环境中持续生成数据(观察帧及对应的 Q 值输出),存入经验回放缓冲区;学生网络从缓冲区中采样,以上述某种损失函数进行监督训练。教师在数据收集过程中以 -greedy 策略行动(),以保证数据的多样性。

多任务策略蒸馏

多任务蒸馏的目标是将 个独立训练的单任务 DQN 教师合并为一个统一的学生网络。每个教师分别生成数据存入各自的缓冲区,学生网络在训练时每个 episode 切换到不同任务的缓冲区中采样。由于不同任务的动作空间可能不同,学生网络采用共享卷积特征提取层加独立输出层(称为 controller)的架构——每个任务对应一个独立的 MLP 输出头,训练和评估时根据任务标签切换。

蒸馏 vs. 多任务 DQN 的优势

多任务 DQN(用一个网络同时在多个任务上做 Q-learning)在 Atari 上表现很差,原因包括不同任务间的策略干扰、奖励尺度差异、以及价值函数学习的内在不稳定性。策略蒸馏绕开了这些问题:每个教师独立训练到收敛后再蒸馏,蒸馏过程只涉及监督学习,不受 Q-learning 不稳定性影响;且策略(softmax 后的动作分布)天然比价值函数具有更低的方差。

在线策略蒸馏

在线蒸馏将策略蒸馏嵌入 DQN 的训练过程中:DQN 教师正常训练,每当达到新的最高分时保存当前网络作为”最佳教师”;学生网络持续蒸馏当前最佳教师的策略。这一设计的动机是 DQN 训练过程中策略剧烈波动(得分方差很大),而蒸馏可以起到平滑和稳定的作用——学生追踪的始终是教师历史上的最佳策略,不会受到教师训练过程中的性能回退影响。

4. Experiments

实验在 10 个 Atari 游戏上进行,涵盖 DQN 表现从超人水平(Breakout、Space Invaders)到低于人类水平(Q*bert、Ms.Pacman)的不同难度等级。评估采用人类专家起始状态的泛化测试协议(丢弃人类操作阶段累积的分数),以避免智能体仅记忆固定起始位置的轨迹。

损失函数对比

在网络结构与教师完全相同的条件下,对比三种损失函数在 4 个游戏上的蒸馏效果:

游戏DQN (score)Dist-MSE (%DQN)Dist-NLL (%DQN)Dist-KL (%DQN)
Breakout303.933.977.694.7
Freeway25.899.4101.4103.5
Pong16.294.494.9100.9
Q*bert4589.8122.2147.6155.0

KL 损失在所有游戏上均为最优,且在多数游戏上学生超越教师。MSE 在 Breakout 上表现极差(仅 33.9%),印证了 MSE 将优化预算浪费在对策略无关紧要的 Q 值微小差异上的问题。

模型压缩

使用 KL 损失将 10 个 DQN 教师分别蒸馏到三种不同大小的学生网络(参数量分别为教师的 25%、7%、4%):

  • Dist-KL-net1(25% 参数,428K):平均得分为教师的 108%,压缩 4 倍反而超越教师
  • Dist-KL-net2(7% 参数,113K):平均得分为教师的 102%,压缩 15 倍仍与教师持平
  • Dist-KL-net3(4% 参数,62K):平均得分为教师的 84%,压缩 25 倍仍保留大部分性能

这一结果强烈表明 DQN 的最终策略远不需要训练时那么大的网络容量。直接训练小网络做 Q-learning 效果很差,因为小网络在 Q-learning 的策略迭代过程中无法快速适应不断变化的值函数估计,而蒸馏绕过了这一问题。

多任务蒸馏

在 3 个游戏的实验中,三种方法使用相同大小的网络(约 90% 参数共享,仅输出 controller 不同):

  • Multi-DQN(多任务 Q-learning):几何平均得分为单任务教师的 83.5%
  • Multi-Dist-NLL(NLL 蒸馏):105.1%
  • Multi-Dist-KL(KL 蒸馏):116.9%

多任务蒸馏不仅远超多任务 DQN,甚至超越了各单任务教师。在 10 个游戏上的扩展实验中,单个蒸馏网络(参数约为单个教师的 4 倍)达到了 10 个教师几何平均得分的 89.3%,其中有 3 个游戏超越教师。

在线蒸馏

以 Q*bert 为例,DQN 训练过程中得分波动剧烈(方差极大),而在线蒸馏学生的学习曲线平滑且稳定,最终性能与教师历史最佳相当或更好。这表明蒸馏可以作为一种实用的正则化手段,在 RL 训练过程中过滤掉策略的剧烈抖动。

实验局限性

  • 所有实验限于 Atari 离散动作空间,未验证连续控制场景的适用性
  • 损失函数的比较仅在 4 个游戏上进行参数选择,然后固定应用于其余 6 个游戏,可能存在选择偏差
  • 多任务蒸馏采用独立 controller 输出层,在任务数量增多时扩展性存疑
  • 10 个游戏的多任务蒸馏未与多任务 DQN 对比,因为多任务 DQN 在 10 个游戏上完全失败,缺少可比较的基线
  • 未探索 KL 与 NLL 或 MSE 组合损失的可能性(论文自身也指出了这一点)

模型压缩与蒸馏Model Compression 最早提出用教师网络的输出训练学生网络来实现模型压缩。Distilling the Knowledge in a Neural Network 引入温度参数来软化教师的 softmax 输出,使学生能学到更丰富的类间关系。本文将这一思路扩展到 RL 领域,但关键的不同在于:RL 中教师输出的是 Q 值而非概率分布,需要反向使用温度(锐化而非软化)。

模仿学习DAgger 等交互式模仿学习方法通过让学生策略控制数据收集来解决分布偏移问题。策略蒸馏与此不同:数据完全由教师生成,学生不与环境交互。CAPI 框架的一次迭代可以视为使用特定损失函数(按动作间隙加权分类)的策略蒸馏,与本文的 KL 损失在精神上一致。

多任务学习:传统多任务学习要求任务共享输入分布,而 Atari 游戏的图像在不同游戏间差异极大,不共享共同的统计结构。策略蒸馏通过将各任务的策略先独立训练到收敛、再合并蒸馏,绕开了多任务 RL 训练中的干扰和不稳定性问题。

Future Work

论文自身提及的方向:

  • 探索 KL 与其他损失函数的组合
  • 将策略蒸馏应用于更广泛的 RL 算法(不限于 DQN)
  • 进一步研究蒸馏作为正则化手段的机制

自然延伸的方向包括:将策略蒸馏推广到连续动作空间(此时 softmax 温度的类比需要重新设计);探索学生网络与环境交互式训练的方案(结合模仿学习中的在线纠正思路);以及在多任务场景下研究共享动作空间的统一输出架构。