Deep Cocktail Network: Multi-source Unsupervised Domain Adaptation with Category Shift
SUMMARY@ 2020/5/12
Table of Contents
- 1. Method abstract
- 2. Motivation
- 3. Challenges /Problem to be solved
- 4. Contribution
- 5. Related work
- 5.1 Unsupervised domain adaptation with single source
- 5.2 Domain adaptation with multiple sources
- 5.3 Two branches of transfer learning closely related to MDA (supervised)
- 6. Settings
- 7. Compared with Open Set DA
- 8. DCTN: framework details
- 8.1 Feature extractor
- 8.2 (Multi-source) domain discriminator
- 8.3 (Multi-source) category classifier
- 8.4 Target classification operator
- 8.5 Connection to distribution weighted combining rule
- 9 Learning
- 9.1 Pre-training C and F
- 9.2 Multi-way Adversarial Adaptation
- Online hard domain batch mining
- 9.3 Target Discriminative Adaptation
- 10. Experiments
- 10.1 Benchmarks
- 10.2 Evaluations in the vanilla (common) setting
- 10.3 Evaluations in the category shift setting
- 11. Further Analysis
1. Method abstract
Inspired by the distribution weighted combining rule in [33], the target distribution can be represented as the weighted combination of the multi-source distributions.
An ideal target predictor can be obtained by integrating all source predictions based on the corresponding source distribution weights.
- Besides the feature extractor, DCTN also includes:
- a (multi-source) category classifier to predict the class from different sources
- a (multi-source) domain discriminator to produce multiple source-target-specific perplexity scores as the approximation of the source distribution weights.
During training, two adaptation steps alternate:
- Domain discriminator: deploys multi-way adversarial learning to minimize the discrepancy between the target and each of the multiple source domains; this multi-way adversarial adaptation implicitly reduces the domain shift among the sources.
- It also predicts the source-specific perplexity scores, i.e. the possibilities that a target sample belongs to the different source domains.
- Feature extractor and category classifier: the multi-source category classifiers are integrated with the perplexity scores to classify each target sample, and the pseudo-labeled target samples together with the source samples are used to update the multi-source category classifier and the feature extractor.
2. Motivation
This paper focuses on the problem of multi-source domain adaptation, where there is category shift between diverse sources.
Category shift is a new protocol in MDA, where domain shift and categorical disalignment co-exist among the sources.
This paper aims to handle domain shift and category shift together.
3. Challenges /Problem to be solved
- We cannot simply apply the same UDA method by combining all source domains, since there may be domain shifts among the sources.
- Eliminating the distribution discrepancy between the target and each source may be too strict, and even harmful.
- There is category shift among the sources.
4. Contribution
- We present a novel and realistic MDA protocol termed category shift that relaxes the requirement of a shared category set among the source domains.
- Inspired by the distribution weighted combining rule, we propose the deep cocktail network (DCTN) together with an alternating adaptation algorithm to learn transferable and discriminative representations.
- We conduct comprehensive experiments on three well-known benchmarks and evaluate our model in both the vanilla and the category shift settings. Our method achieves state-of-the-art results across most transfer tasks.
5. Related work
5.1 Unsupervised domain adaptation with single source
- domain discrepancy based methods: reduce the domain shift across the source and the target
- deep-model-based methods
- adversarial learning based methods
- others: semi-supervised method [42], domain reconstruction [14], duality [19], alignments [9] [50] [44], manifold learning [15], tensor methods [24],[31], etc.
5.2 Domain adaptation with multiple sources
- originates from A-SVM [49]
- shallow models [8][22][27]
- theoretical:
- learning bounds for multi-source DA [3]
- distribution weighted combining rule [33]
5.3 Two branches of transfer learning closely related to MDA (supervised)
- continual transfer learning (CTL) [43] ,[39].
- CTLs train the learner to sequentially master multiple tasks across multiple domains.
- domain generalization (DG)
- uses the existing multiple labeled domains for training, without touching the unlabeled target samples [13, 35]
6. Settings
- Suppose the classifier for each source domain is known.
- Vanilla MDA: samples from the diverse sources share the same category set.
- Category shift: categories from different sources might also be different.
- There are N different underlying source distributions.
- There is one target distribution, with no labels.
- Training set: the ensemble of the source datasets.
- Testing set: drawn from the target distribution.
- Target samples get labeled by the union of all the categories in the sources.
7. Compared with Open Set DA
- In open set DA, the uncommon classes are unified as a negative category called "unknown".
- In contrast, category shift considers the specific disaligned categories among the multiple sources, which enriches the classification in transfer.
8. DCTN: framework details
8.1 Feature extractor
- deep convolutional nets as the backbone
- shared weights: maps all images from the N sources and the target into a common feature space
- adversarial learning is employed to obtain the optimal mapping, because it can successfully learn both domain-invariant features and the target-source-specific relations
8.2 (Multi-source) domain discriminator
- N source-specific discriminators.
- Given an image from a source or the target domain, the domain discriminator receives its extracted features and classifies whether they come from that source or from the target.
- For the data flow of each target instance, the domain discriminators yield N source-specific discriminative results.
- Target-source perplexity scores: each score involves a source-specific concentration constant, obtained by averaging the source discriminator losses (see the supplementary material; different scores use different constants).
- In the score formula, one quantity counts how many times the target samples have been visited to train the model, and another indexes the source-j instances that come along with the coupled target instances in the adversarial learning.
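As a rough sketch of how perplexity scores could be formed from the discriminator outputs (the function name, the `-log(1 - d)` transform and the normalization are illustrative assumptions, not the paper's exact formula, which additionally involves the concentration constants):

```python
import math

def perplexity_scores(disc_outputs, eps=1e-12):
    # disc_outputs[j]: the j-th discriminator's estimated probability that
    # a target feature comes from source j (a value in (0, 1)).
    # A target feature that looks more like source j yields a larger
    # D_j output, hence a larger (normalized) score for source j.
    raw = [-math.log(1.0 - d + eps) for d in disc_outputs]
    total = sum(raw)
    return [r / total for r in raw]
```

The scores form a simplex, so they can directly re-weight the per-source predictions downstream.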
8.3 (Multi-source) category classifier
- A multi-output net composed of N source-specific predictors.
- Each predictor is a softmax classifier.
- For an image from source j, only the value from the j-th predictor gets activated and provides the gradient for training.
- For a target image instead, all source-specific predictors provide categorization results to the target classification operator.
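The gating behavior can be sketched in plain Python (the function names and the two-head setup are illustrative assumptions; a real implementation would use a deep learning framework):

```python
import math

def softmax(logits):
    # numerically stable softmax over one head's logits
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def source_image_loss(head_logits, source_id, label):
    # head_logits[j] are the logits of the j-th source-specific predictor.
    # For an image drawn from source `source_id`, only that head is
    # activated: its cross-entropy loss is the sole training signal,
    # while the other heads receive no gradient.
    probs = softmax(head_logits[source_id])
    return -math.log(probs[label] + 1e-12)
```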
8.4 Target classification operator
- For each target feature, the target classification operator takes each source perplexity score to re-weight the corresponding source-specific prediction.
- The confidence that a target sample belongs to a class c is the perplexity-score-weighted combination of the softmax values for c from the relevant sources:
- the softmax value of source j corresponding to class c contributes one term;
- under category shift, only the sources that contain class c join the perplexity score weighting;
- in the vanilla setting, all the sources join.
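A minimal sketch of this weighting logic under category shift (the dict-based class sets and the normalization over active sources are assumptions for illustration; the paper defines the operator with an exact formula):

```python
def target_confidence(per_source_probs, scores, c):
    # per_source_probs[j]: dict mapping each class in source j's category
    # set to its softmax value; a source lacking class c simply has no
    # entry for c and is excluded from the weighting (category shift).
    # scores[j]: the perplexity score for source j.
    active = [j for j, p in enumerate(per_source_probs) if c in p]
    denom = sum(scores[j] for j in active)
    if denom == 0.0:
        return 0.0
    return sum(scores[j] * per_source_probs[j][c] for j in active) / denom
```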
8.5 Connection to distribution weighted combining rule
- In the distribution weighted combining rule [33], the target distribution is treated as a mixture of the multi-source distributions, with coefficients given by the source distributions weighted by unknown positive weights and normalized. Note that the hypothesis there has a one-dimensional output.
- In this paper, the ideal target classifier is presented as the weighted combination of the source classifiers; note that here each source classifier outputs a multi-class softmax result.
- As the probability that a target sample comes from source j increases, so does the corresponding perplexity score; the scores therefore substitute for the distribution weighting.
- Target images should be categorized by the classifiers from multiple sources: the more similar a source's features are to the target, the more trustworthy that source classifier's predictions.
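The combining rule can be written out explicitly (a reconstruction under the definitions above; the paper's exact notation may differ, and DCTN replaces the normalized mixture factors with the perplexity scores):

```latex
% Target distribution as a mixture of the N source distributions,
% with unknown positive weights \lambda_j:
D_T(x) = \sum_{j=1}^{N} \lambda_j D_{S_j}(x), \qquad
\lambda_j > 0,\; \sum_{j=1}^{N} \lambda_j = 1
% Ideal target predictor: each source hypothesis h_{S_j} is weighted by
% the normalized probability that x was drawn from source j:
h_T(x) = \sum_{j=1}^{N}
  \frac{\lambda_j D_{S_j}(x)}{\sum_{k=1}^{N} \lambda_k D_{S_k}(x)}\, h_{S_j}(x)
```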
9 Learning
9.1 Pre-training C and F
- Take all the source images to jointly train the feature extractor F and the category classifier C.
- Pseudo labels for the target: those networks and the target classification operator then predict categories for all target images and annotate those with high confidence.
- Since the domain discriminator hasn't been trained yet, we take uniform simplex weights as the perplexity scores for the target classification operator.
- Finally, we obtain the pre-trained feature extractor and category classifier by further fine-tuning them with the source and pseudo-labeled target images.
- In object recognition, DCTN is initialized in the same way as DAN (start with an AlexNet model pretrained on ImageNet 2012 and fine-tune it). For digit recognition, DCTN is learned from scratch.
9.2 Multi-way Adversarial Adaptation
Reference: the ADDA paper, Adversarial Discriminative Domain Adaptation.
- Original GAN objective (M denotes the mapping / feature extractor).
- Change 1: early in training the discriminator converges quickly, causing the generator's gradient to vanish. The standard remedy changes the generator objective and splits the optimization into two independent objectives, one for the generator and one for the discriminator.
- Change 2: in the setting where both distributions are changing, this objective leads to oscillation: when the mapping converges to its optimum, the discriminator can simply flip the sign of its prediction in response. Tzeng et al. instead proposed the domain confusion objective, under which the mapping is trained using a cross-entropy loss against a uniform distribution. This loss ensures that the adversarial discriminator views the two domains identically.
- "Confusing" the discriminator means keeping it "half-believing": the marginal distributions of the mapped source and target should be as close as possible, so that every sample looks roughly equally likely to come from either domain (equivalently, each sample's true domain label looks like a 50/50 draw between 1 and 0). Minimizing this cross-entropy loss drives the mapped source and target distributions toward each other; once they are mapped into very similar domains, the DA task is done.
- Note: the (*) objective is not actually used in ADDA's results; it only appears as related work. The (*) objective was proposed in "Simultaneous Deep Transfer Across Domains and Tasks"; ADDA changes the generator's optimization objective to (**).
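The domain confusion objective can be sketched numerically (a minimal illustration; the function name and the scalar-probability interface are assumptions):

```python
import math

def confusion_loss(d_out, eps=1e-12):
    # d_out: discriminator's estimated probability that a sample comes
    # from the source domain. The mapping is trained so this output
    # matches the uniform label (1/2, 1/2), i.e. cross-entropy against
    # a uniform distribution; the minimum log(2) is reached at d_out=0.5.
    p = min(max(d_out, eps), 1.0 - eps)
    return -0.5 * math.log(p) - 0.5 * math.log(1.0 - p)
```

A confident discriminator output (e.g. 0.9) incurs a larger loss than a maximally confused one (0.5), which is exactly the pressure that pushes the mapped domains together.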
In this paper:
- Min-max adversarial domain adaptation:
- the classifier is fixed so as to provide stable gradient values;
- the first term denotes the adversarial mechanism;
- the second term is the multi-source classification loss.
- The optimization based on Eq. 4 works well for the discriminator but not for the feature extractor: since the feature extractor learns the mapping from the multiple sources and the target, the domain distributions keep changing simultaneously in the adversary, which results in oscillation that spoils the feature extractor.
- When the source and target feature mappings share their architectures, the domain confusion objective can replace the adversarial objective and learns the mapping stably.
- Multi-domain confusion loss: the feature extractor is trained against the uniform domain label for every source-target discriminator.
- Differences from (*):
- there is no negative sign;
- this is multi-source, so there are N discriminators, each handling the domain decision between one source and the target;
- in (*) the source and target mappings are different, whereas here the feature extractor is shared.
- In this paper, (*) is directly modified into a loss function shared by the discriminators and the generator (up to a sign flip, since the value is negative), expressing the gap between the target and each source. The cross-entropy measures the discrepancy between two distributions and is always positive.
- Maximizing Eq. 6 is equivalent to minimizing the cross-entropy loss, i.e. optimizing the discriminators.
Online hard domain batch mining
- Samples from different sources are sometimes useless for improving the adaptation to the target; as training proceeds, more and more redundant source samples even drag down the whole model's performance.
- Minibatch: sample a batch for the target and for each source domain.
- Each source-target discriminator's loss is viewed as the degree to which it fails to distinguish the target samples from the j-th source's samples. This is the cross-entropy loss in the original GAN form: a larger loss means the discriminator separates the M source-j samples and the M target samples worse, i.e. that source's discriminator is weak.
- Find the hard source domain: the one for which the feature extractor performs the worst at transforming the target samples to confuse the j-th source.
- We then use that source's and the target's samples in the minibatch to train the feature extractor.
- Algorithm 1 iterates these updates to find the best feature extractor.
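The mining rule might be sketched as follows (illustrative; it assumes the "hard" source is the one whose discriminator loss is smallest, i.e. the one the extractor confuses worst, per the loss interpretation above):

```python
def hardest_source(disc_losses):
    # disc_losses[j]: the j-th source-target discriminator's cross-entropy
    # loss on the current minibatch. A SMALL loss means discriminator j
    # still separates its source from the target easily, so the feature
    # extractor confuses that source worst -- making it the "hard" domain
    # selected to train the extractor this iteration.
    return min(range(len(disc_losses)), key=lambda j: disc_losses[j])
```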
9.3 Target Discriminative Adaptation
- Aided by the multi-way adversary, DCTN obtains good domain-invariant features, yet they are not necessarily classifiable in the target domain.
- Auto-labeling strategy: annotate target samples, then jointly train the feature extractor and the multi-source category classifier with the source and target images by their (pseudo-)labels.
- The loss combines the classification losses from the multi-source images and from the target images with pseudo labels.
- We apply the target classification operator to assign pseudo labels, and the samples with confidence higher than a preset threshold are selected into the pseudo-labeled set.
- Given a target instance with pseudo-labeled class c, we find the sources that include this class, then update the network via the sum of those sources' classification losses.
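The threshold-based selection can be sketched as follows (the names and the 0.9 threshold are illustrative assumptions, not the paper's values):

```python
def select_pseudo_labels(target_confidences, threshold=0.9):
    # target_confidences[i]: dict mapping each candidate class to the
    # target classification operator's confidence for target sample i.
    # Samples whose top confidence clears the preset threshold receive
    # that class as a pseudo label; the rest stay unlabeled this round.
    selected = []
    for i, conf in enumerate(target_confidences):
        cls, p = max(conf.items(), key=lambda kv: kv[1])
        if p >= threshold:
            selected.append((i, cls))
    return selected
```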
10. Experiments
10.1 Benchmarks
- 3 widely used UDA benchmarks
- Office-31 [41]:
- an object recognition benchmark with 31 categories and 4,652 images unevenly spread across three visual domains: A (Amazon), D (DSLR), W (Webcam).
- ImageCLEF-DA:
- 50 images in each category
- totally 600 images for each domain
- derives from the ImageCLEF 2014 domain adaptation challenge, and is organized by selecting 12 object categories (aeroplane, bike, bird, boat, bottle, bus, car, dog, horse, monitor, motorbike, and people) shared by three famous real-world datasets: I (ImageNet ILSVRC 2012), P (Pascal VOC 2012), C (Caltech-256).
- Digits-five
- five digit image sets respectively sampled from following public datasets
- mt (MNIST) [26]
- mm (MNIST-M) [11]
- sv(SVHN) [36]
- up (USPS)
- sy (Synthetic Digits) [11].
- For the images in MNIST, MNIST-M, SVHN and Synthetic Digits, we draw 25,000 for training and 9,000 for testing from each dataset.
- There are only 9298 images in USPS, so we choose the entire dataset as our domain.
10.2 Evaluations in the vanilla (common) setting
Baselines
- Multi-source: two shallow methods
- sparse FRAME (sFRAME) [46]
- a non-stationary Markov random field model that reproduces the observed statistical properties of filter responses at a subset of selected locations, scales and orientations.
- representing a wide variety of object patterns in natural images and that the learned models are useful for object classification.
- SGF [16]
- Motivated by incremental learning, we create intermediate representations of data between the two domains by viewing the generative subspaces (of same dimension) created from these domains as points on the Grassmann manifold, and sampling points along the geodesic between them to obtain subspaces that provide a meaningful description of the underlying domain shift.
- Single-source models → multi-source: conventional (TCA, GFK) / deep.
Since those methods operate in the single-source setting, we introduce two MDA standards for different purposes:
- Source combine: all source domains are combined into a traditional single-source vs. target setting. This standard testifies whether the multiple sources are valuable to exploit.
- Single best: among the multi-source domains, we report the single-source transfer result that performs best on the test set. This tests whether the best single-source UDA can be further improved by introducing another source.
- Source only: used as the baseline in both the source-combine and multi-source standards; it uses all images from the sources to train backbone-based multi-source classifiers and directly applies them to classify target images.
10.3 Evaluations in the category shift setting
- Split all categories into two non-overlapped class sets and define them as the private classes, under two settings:
- overlap
- disjoint
- DAN also suffers negative transfer gains in most situations, which indicates that the transferability of DAN is crippled under category shift.
- In contrast, DCTN reduces the performance drop compared to the model in the vanilla setting, and obtains positive transfer gains in all situations. This reveals that DCTN can resist the negative transfer caused by category shift.
11. Further Analysis
11.1 Feature visualization.
visualize the DCTN activations before and after adaptation.
- DCTN can successfully learn transferable features with multiple sources
- features learned by DCTN attain a desirable discriminative property
11.2 Ablation study
- The adversarial-only model excludes the pseudo labels and updates the category classifier with source samples only.
- The pseudo-only model forbids the adversary and categorizes target samples with averaged multi-source results.
- The third variant removes the domain batch mining technique.
11.3 Convergence analysis
Despite frequent deviations, the classification loss, the adversarial loss and the testing error gradually converge.