PER的由来
在PER之前,像DQN(Nature 2015)以及Double DQN等Deep Q-learning方法都是通过经验回放的手段进行目标Q值的计算的。在采样的时候对于经验回放池中的所有transition,被采样到的概率都是相等的。
但是上述方法可能在某些情况下会非常影响训练效率。在PER的论文中提出了BLIND CLIFFWALK测试环境,其奖励十分稀疏,假设有\(n\)个状态,在随机选择的动作序列中只有\(2^{-n}\)左右的概率才能得到第一个非零奖赏。在这种情况下,对于模型优化最相关的transition只能隐藏在众多的失败尝试中。
优先经验回放可以让对学习过程更有用的样本以更高的频率回放,其中对学习过程的作用使用时序差分误差(TD-error)的大小来衡量。 通过这种方式可以加快收敛速度。
PER的方法选择
TD-error表示了当前价值函数输出的动作价值与对该动作价值的估计值之间的差异。越大的TD-error表示当前价值函数的输出越不准确,也就暗示可以从该样本中学到更多。由于新生成的样本不知道其TD-error,所以把它放在第一位,以保证至少回放一次,之后每次都回放TD-error最大的样本。
但是单独的使用TD-error来做贪婪优先经验回放(greedy TD-error prioritization)存在许多问题:
第一,为了避免在整个经验池上更新TD-error, 只有被重放的经验的TD-error被更新,这样就会导致TD-error较小的样本在第一次被回放之后很久都不会被重新回放(在参数更新之后此样本的TD-error可能会变大,然而由于不更新,它会一直被认为很小); 第二,对噪声敏感(例如奖励是随机的时候或者在噪声影响下,某些经验的TD-error可能始终不会减小,进而导致其不断被重放),bootstrapping会进一步加剧这个问题,因为在bootstrapping中近似函数的误差也是噪声的一种。最后,贪婪优先回放的样本会集中在一个小范围内,因为在非线性近似中,TD-error的缩小是很慢的,开始TD-error大的样本可能会被回放很多次,这种多样性的缺失容易导致过拟合(over-fitting)的产生。
为了解决上述问题,作者提出了Rank-based prioritization以及Proportional prioritization两种方法。
Rank-based方法顾名思义,就是根据经验回放池中各个transition的TD-error的值进行排序,根据排名来确定被选定概率的方法。具体如下:
\[ p_{i}=\frac{1}{\operatorname{rank}(\mathrm{i})} \]
其中\(rank(i)\)是经验回放池中根据TD-error排行的第\(i\)个transition的位置。通过这种方式也可以实现batch-size sampling,可以将排名段分成几个等概率区间,再在各个等概率区间里面均匀采样。比如说,假设experience replay大小为100,batch size为3,那么就事先通过计算,将100分为3个等概率区间(e.g., A:1-20,B:21-50, C:51-100),之后就在A、B和C区间内分别做均匀采样,最后取得3个transition。
在Proportional prioritization方法中,transition被选择的概率正比于其TD-error的值。概率\(p_{i}\)的表达式如下:
\[ p_{i}=\left|\delta_{i}\right|+\epsilon \]
其中\(\delta_{i}\)为TD-error,加入\(\epsilon\)防止出现概率为0的情况。在实际实现过程中会使用SumTree结构降低搜索复杂度。详情请参考此处。
Proportional prioritization方法较为准确,但可能会对outlier较为敏感,Rank-based prioritization方法相对没有定量的考虑TD-error的值,但是具有更好的鲁棒性。
重要性采样
如之前的博文所说,Q-learning算法通常以拟合Bellman最优函数为目标:
\[ Q_{\pi}(s, a)=R(s, a)+\gamma \mathbb{E}_{s^{\prime} \sim p(s, a)}\left[\max _{a^{\prime}} Q_{\pi}\left(s^{\prime}, a^{\prime}\right)\right] \]
通过大量的采样,我们可以使得到的状态\(s^{\prime}\)逼近其分布\(p(s, a)\)。由于在DQN算法中经验回放池的采样是等概率随机的,因此不会为上述式子的计算引入偏差。但是当使用PER进行采样时,因为对于transition的采样概率不再是相等的,所以计算出来的结果也有了偏差。
为了消去偏差,我们使用重要性采样。其过程如下:
\[ E_{x \sim p}[f(x)]=\int f(x) p(x) d x=\int f(x) \frac{p(x)}{q(x)} q(x) d x=E_{x \sim q}\left[f(x) \frac{p(x)}{q(x)}\right] \]
将\(p(x)=1 / N\)以及\(q(x)=P(j)\)代入后,就能得到各个transition的重要性采样权重:
\[ \omega_{j}=\left(\frac{1}{N} \cdot \frac{1}{P(j)}\right)^{\beta} \]
其中\(\beta\)用来调节bias修正程度(学习的初始阶段有bias也没所谓,但在后期就要消除bias)。
将权重归一化后得到:
\[ \omega_{j}=\frac{(N \cdot P(j))^{-\beta}}{\max _{i} \omega_{i}} \]
算法流程
算法流程参考Pinard的博客。
Python实现参考我的Github。