Gradient Episodic Memory for Continual Learning
David Lopez-Paz and Marc’Aurelio Ranzato
Facebook Artificial Intelligence Research
论文主要探讨了持续学习(continual learning)的问题,即模型在不断接触到新任务时,如何快速解决新问题而不遗忘之前学到的知识。作者提出了一种新的评估模型学习连续数据集的指标,并提出了一个名为Gradient Episodic Memory (GEM)的模型,该模型通过缓解遗忘并允许有益的知识传递到先前任务来解决这个问题。
论文回顾了持续学习的相关研究,包括多任务学习、转移学习、零次学习、课程学习等。特别提到了对抗灾难性遗忘的方法,如冻结早期网络层、使用模块化网络结构、调整学习率以及基于“记忆”的方法。
GEM模型通过使用一个“情节记忆”(episodic memory)来存储每个任务的观察示例的子集。这个记忆帮助模型在新任务学习时最小化对旧任务性能的负面影响(即灾难性遗忘),同时允许有益的知识传递。GEM通过解决一个带有不等式约束的优化问题来更新模型参数,这些约束确保了对旧任务的损失不会增加。
情节记忆的构建:
GEM模型为每个任务分配一定数量的记忆位置(memory locations),用于存储该任务的最后几个示例。这些示例是从当前任务中观察到的,并且用于后续的参数更新过程中。
参数更新过程中的约束:
在观察到新的数据样本
(
x
,
t
,
y
)
(x, t, y)
(x,t,y)时,GEM模型不仅需要最小化当前任务的损失,还需要确保对之前任务的损失不增加。这通过将之前任务的损失作为不等式约束来实现。
不等式约束的数学表示:
对于每个先前的任务
k
(
k
<
t
)
k(k < t)
k(k<t),GEM模型需要满足以下不等式约束:
?
g
,
g
k
?
≥
0
for?all
k
<
t
\langle g, g_k \rangle \geq 0 \quad \text{for all} \quad k < t
?g,gk??≥0for?allk<t
其中,
g
g
g是当前任务的参数更新,
g
k
g_k
gk?是之前任务k的损失梯度。
投影梯度:
如果当前的参数更新
g
g
g违反了不等式约束,GEM会将
g
g
g投影到满足所有约束的最近的梯度上。这个投影梯度记为
g
~
\tilde{g}
g~?,并且可以通过解决一个二次规划问题(Quadratic Program, QP)来找到:
minimize
g
~
2
1
∣
∣
g
?
g
~
∥
2
2
subject?to
?
g
~
,
g
k
?
≥
0
for?all
k
<
t
\text{minimize} \tilde{g} \frac 2 1 ||g-\tilde{g}\|_2^2\quad\text{subject to}\quad\langle\tilde{g},g_k\rangle\geq0\quad\text{for all}\quad k<t
minimizeg~?12?∣∣g?g~?∥22?subject?to?g~?,gk??≥0for?allk<t
这里的
g
~
\tilde{g}
g~?是投影后的梯度,它在满足所有约束的同时,尽可能接近原始梯度
g
g
g。
二次规划问题的求解:
为了高效地解决上述二次规划问题,GEM利用了二次规划的对偶形式。对偶问题的形式如下:
minimize
v
1
2
v
T
G
G
T
v
+
g
T
G
T
v
subject?to
v
≥
0
\text{minimize}_{v} \frac{1}{2} v^T G G^T v + g^T G^T v \quad \text{subject to} \quad v \geq 0
minimizev?21?vTGGTv+gTGTvsubject?tov≥0
其中,
G
=
?
(
g
1
,
.
.
.
,
g
t
?
1
)
G = -(g_1, ..., g_{t-1})
G=?(g1?,...,gt?1?),并且我们忽略了常数项
g
T
g
g^T g
gTg。这个对偶问题涉及到的变量数量远少于原始问题,因为它只依赖于到目前为止观察到的任务数量
(
t
?
1
)
(t-1)
(t?1),而不是模型的参数数量
(
p
)
(p)
(p)。
更新参数:
一旦求解出对偶问题的解
v
?
v*
v?,就可以恢复投影后的梯度更新
g
~
=
G
T
v
?
+
g
\tilde{g} = G^T v* + g
g~?=GTv?+g。然后,模型参数通过这个更新进行更新。
通过这种方法,GEM能够在学习新任务的同时,尽量保持对旧任务的预测性能,从而缓解灾难性遗忘。
论文在MNIST和CIFAR-100数据集的变体上进行了实验,这些数据集模拟了模型在连续任务中观察到的示例序列。实验结果表明,GEM在与最先进的方法相比时,表现出强大的性能。
论文提出了三个可以进一步研究的方向:1) 利用结构化的任务描述符以实现零次学习;2) 探索高级的记忆管理策略,如构建任务的核心集;3) 减少每个GEM迭代所需的计算时间,因为当前的方法需要对每个任务进行一次反向传递。
论文提出了一个名为GEM的模型来解决持续学习中的灾难性遗忘问题,并在实验中展示了其有效性。GEM通过使用情节记忆来最小化对旧任务性能的负面影响,并允许有益的知识传递。尽管GEM在实验中表现出色,但仍有改进的空间,特别是在利用任务描述符、记忆管理和计算效率方面。