《DLOW:Domain Flow for Adaptation and Generalization》论文解析

今天说的这篇文章,也是用来解决迁移学习问题的。迁移学习要解决一个什么问题呢?就是要把模型在source域(源域)学习到的知识,用到target域(目标域)里。

DLOW这篇文章主要提出了两点:1、可以把source域的数据迁移成中间域,中间域也就是介于source和target之间的域。  2、训练的时候如果有多个target域的话,DLOW可以生成网络没有见过的数据风格。

那么接下来介绍一下算法原理:

1、CycleGAN

作为本算法的基础,cycleGAN至关重要。具体的介绍请看我另一篇博客:

https://blog.csdn.net/wenqiwenqi123/article/details/105123491

在这里复习一下,cycleGAN主要由两个loss组成:

其中xs代表源域数据,xt代表目标域数据。Gst为source到target的生成器,Dt为判别目标域真实数据和生成数据的判别器。反之相同。

 

2、定义中间域

设中间域为M(z),z为[0,1]的一个变量,跟与source和target的联系有关。换句话说,z=0的时候,M(z)就是source,当z=1的时候,M(z)就是target。

如下图所示,其实从S到T的路径有许多,但是我们希望我们能找到一条最近的路(贴着地平线过去的那条,红线)。

因此我们得到了如下公式:

其中dist为某种距离表示,在本算法中使用了公式一的距离(cycleGAN)。

因此把3公式化简一下,得到了loss函数:

 

3、DLOW模型

综上,我们现在有了Source域的数据,和Z=[0,1],Gst的目的是得到中间域而不再是得到Target域。

因此有

Adversarial Loss:GAN肯定会有对抗损失。这里定义判别器Ds(x)是区分M(z)和S,而Dt(x)是区分M(z)和T。因此对抗损失可以写为:

用上面的损失来代入dist的话,则得到:

Image Cycle Consistency Loss:cycleGAN的循环一致损失:

其中Gts是从target到M的生成器:

总损失:

 

实现:

整个网络架构如下所示:

z和S一起作为输入输进生成器Gst,对z作反卷积得到(1,16,1,1)的向量,同时对z进行采样:

这样的话,z会在一开始的时候趋向于较小,随着训练逐渐加大,这样可以更稳定。

 

提升域自适应模型:

作者做了一个实验,可以提升域自适应算法的能力。

原本的source域的数据为S,那么作者用DLOW这套算法,把z从[0,1]中均匀采样,用生成器GST生成了新的数据集S~。

因此在S~中,数据分布从S到T都有,再用S~数据集作为域自适应算法的训练数据,可以有效提升效果。

至于此处把Ladv赋予了一个权值,根号1-z,是因为对于每一个样本来说,如果z比较大的话,说明这个样本更接近target域,因此对抗损失的权值需要降低。

 

风格生成:

大部分的风格迁移算法,都只能一对一地进行迁移。也就是说在训练完后,就只能迁移到那个风格了。

但是DLOW可以生成训练数据里没有见过的风格。假设有K个目标域,则z拓展成一个k维的向量[z1,z2,...,zk],所有z值加起来等于1。因此我们需要优化的目标变成了:

可以比较容易地修改网络结构得到这个。

 

4、实验部分

实验部分作者做了俩实验,一个是刚刚说的用生成的中间域数据训练domain adaptation的模型,得到更好的结果。一个是风格迁移得到新风格unseen in the training data。

先说实验一,做了一个语义分割的task,从GTA5迁移到Cityscapes。结果如下:

 

实验二,作者用了真实照片迁移到油画风格的task。用了莫奈、梵高等不同的target domain。

因此z变成了[z1,z2,z3,z4],其中z1+z2+z3+z4=1 

训练的时候每五步分别用:[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1] ,均匀随机取样。得到了不错的结果。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章