February 09, 2024
Source-free domain adaptation (SFDA) mitigates the discrepancy between domains without requiring access to the source data, thereby respecting data privacy. However, conventional SFDA methods face inherent limitations in medical contexts, where medical data are typically collected from multiple institutions using various equipment. To address this problem, we propose a simple yet effective method, named Uncertainty-aware Adaptive Distillation (UAD), for the multi-source-free unsupervised domain adaptation (MSFDA) setting. UAD performs well-calibrated knowledge distillation at (i) the model level, to deliver a coordinated and reliable base model initialisation, and (ii) the instance level, via model adaptation guided by high-quality pseudo-labels, thereby obtaining a high-performance target domain model. To verify its general applicability, we evaluate UAD on two image-based diagnosis benchmarks built from two multi-centre datasets, where our method shows a significant performance gain over existing works. The code will be available soon.
Unsupervised Domain Adaptation, Multi-source-free, Uncertainty-aware
Unsupervised domain adaptation (UDA) is a promising line of work for compensating for distributional discrepancy [1]. It seeks to utilise transferable knowledge from labelled data drawn from one or more source domains to recognise unlabelled data in the target domain [2]. By mitigating domain shift, UDA has shown great success in a broad spectrum of downstream applications, including classification [3]–[5], segmentation [6]–[8] and object detection [9], [10].
Despite its great promise in general visual perception tasks, existing UDA approaches inherently fall short in medical scenarios, where additional regulations restrict data sharing. To address this problem on medical images, source-free DA methods [4] have been developed, which provide only the pre-trained source model instead of direct access to the source data, thereby preserving privacy.
In this work, we investigate multi-source-free unsupervised domain adaptation (MSFDA) [11], [12], which extends the typical SFDA settings [4], [13] by introducing multiple source domains. It therefore holds the potential to serve as an appealing solution for real-world large-scale medical image analysis studies involving multiple centres. Several recent efforts [12], [14] have made preliminary attempts at the self-supervised clustering pseudo-labelling method [15], which is commonly adopted for MSFDA. However, these tend to be suboptimal for medical image processing in particular: since data from multiple centres differ substantially, models trained on datasets from single or multiple healthcare institutions have not demonstrated a consistent ability to generalise to external sites [16].
To transcend the aforementioned bottlenecks, in this paper we propose a framework for MSFDA in medical image analysis. Our contributions include:
1) We propose a novel algorithm termed Uncertainty-aware Adaptive Distillation (UAD). Our algorithm first identifies the source model whose underlying data distribution is most comparable to the target domain, to deliver a coordinated model initialisation, and then further leverages the complementary knowledge among source models for precise distillation to the target domain;
2) To avoid over- and under-confidence issues, we apply the Temperature Scaling (TS) method for comprehensive confidence calibration over the source models, towards a well-regulated knowledge distillation procedure;
3) We substantiate the effectiveness of the proposed method through comparison experiments and ablation studies across diverse scenarios, demonstrating its practical benefits towards various endpoints of clinical significance.
Figure 1: Overview of the proposed framework. Our framework follows a multi-source domain model pre-training process with a two-stage uncertainty-aware adaptive distillation (UAD) process of model initialisation and pseudo-labelling.
Without involving any source domain data in training the final model, we aim to transfer a series of models, pre-trained on multiple source domains, to a new target domain without any human annotation. In this work, we consider \(K\)-way classification model adaptation. We are given a source model zoo \(\{\theta_S^{j}\}_{j=1}^{N}\), which contains \(N\) source classification models from \(N\) source domains. The \(j\)-th source model \(\theta_S^{j}\) in the zoo, with input space \(\mathcal{X}\) and output space \(\mathcal{Y}\), is trained on the source dataset \(\mathcal{D}_S^j=\{x_{S_j}^i, y_{S_j}^i\}_{i=1}^{n_j}\) with \(n_j\) instances, where \(x_{S_j}^{i} \in \mathcal{X}_{S_j}\) and \(y_{S_j}^{i}\in \mathcal{Y}_{S_j}\). A target classification model \(\theta_T: \mathcal{X} \rightarrow \mathbb{R}^K\) is learned from only \(\{\theta_S^{j}\}_{j=1}^{N}\) and the unlabelled target domain dataset \(\mathcal{D}_T=\{x^i_T\}_{i=1}^{n_T}\) with \(n_T\) instances.
In the proposed framework, we transfer knowledge from multiple source models to adapt to the target domain using pseudo-labels generated by distilling the proper source model. Technically, we learn a set of uncertainty (or, conversely, confidence) measures, for both overall domain-wise and individual instance-wise distillation, corresponding to each source model in the zoo. These measures evaluate both the distributional distance between a given source model and the target domain dataset and the quality of the pseudo-labels. Specifically, we introduce the margin, defined as the difference between the predicted probabilities of the first and second most probable classes [17], as the metric for estimating confidence:
\[\begin{align} \mathcal{M} = \mathop{\mathrm{Topk}}_{k=1}(\delta(\theta(x))) - \mathop{\mathrm{Topk}}_{k=2}(\delta(\theta(x))),\; \label{margin} \end{align}\tag{1}\]
where \(\delta: \mathbb{R}^K \rightarrow (0,1)^K\) denotes the softmax operation, with \(\delta_j(v)=\frac{\exp(v_j)}{\sum_{i=1}^K \exp(v_i)}\) for \(j = 1,\dots,K\) and \(v \in \mathbb{R}^K\). Intuitively, if a model \(\theta\) yields a larger margin \(\mathcal{M}\) when predicting an instance, it is regarded as better suited to extract that instance's features and, ultimately, to classify it.
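For concreteness, the margin of Eq. (1) can be computed in a few lines. The following PyTorch sketch is illustrative only; the function name and tensor shapes are conventions of our own choosing.

```python
import torch
import torch.nn.functional as F

def margin_confidence(logits: torch.Tensor) -> torch.Tensor:
    """Margin of Eq. (1): gap between the two largest softmax probabilities.

    logits: (batch, K) pre-softmax outputs of a classifier.
    Returns a (batch,) tensor; larger values indicate higher confidence.
    """
    probs = F.softmax(logits, dim=1)
    top2 = probs.topk(k=2, dim=1).values  # sorted descending: (batch, 2)
    return top2[:, 0] - top2[:, 1]
```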
In order to prevent the trained target domain model from being disrupted by confounding factors incurred by attributes irrelevant to the target task (e.g., image appearance discrepancy due to inconsistent imaging protocols), and to avoid local minima problems, we propose to perform Uncertainty-aware Adaptive Distillation (UAD) from two complementary perspectives, (i) model-level and (ii) instance-level, towards directed and well-regularised multi-source model adaptation. The overview of our proposed framework is illustrated in Fig. 1.
Model-level UAD: In previous work on multi-source domain adaptation [11], it was common practice to involve all source models, with varying weights, in the subsequent fine-tuning stage. However, we found that if there is a significant domain gap between a particular source model and the target domain, negative transfer [18] can be incurred, resulting in biased adaptation. Thus, to initialise a base target model with minimal disturbance, we collect the pre-trained source models from all domains and estimate an overall confidence measure for each source model's predictions on the target domain data. Specifically, to assess the confidence of a source model \(\theta_S^j\)'s inference on the target domain, we average the per-instance confidence measures over the target data: \(\mathcal{M}_{j} = \frac{1}{n_T}\sum_{i=1}^{n_T} \mathcal{M}_{i,j}\), where \(\mathcal{M}_{i,j}\) denotes the margin (Eq. 1) of model \(\theta_S^j\) on the \(i\)-th target instance. The source model with the largest confidence measure, whose index we define as \(\varepsilon\), is denoted \(\theta_S^\ast = \theta_S^{\varepsilon}\); it is regarded as the model whose underlying data distribution is closest to the target domain and can be considered the optimal teacher:
\[\begin{align} \varepsilon = \mathop{\arg\max}_{j} \big([\mathcal{M}_{j}]_{j=1}^{N}\big).\; \label{pick_margin} \end{align}\tag{2}\]
We assign the source model \(\theta_S^\ast\) as the initial model for SFDA learning on the target data to minimise the gap between the multiple source domains and the target domain.
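A minimal sketch of this selection step follows, reusing the `margin_confidence` helper above; `source_models` (a list of trained classifiers) and `target_loader` (yielding batches of unlabelled target images) are assumed names.

```python
import torch

@torch.no_grad()
def select_initial_model(source_models, target_loader, device="cuda"):
    """Model-level UAD (Eq. 2): return the index of the source model with
    the largest margin averaged over the whole unlabelled target set."""
    mean_margins = []
    for model in source_models:
        model.eval().to(device)
        margins = [margin_confidence(model(x.to(device))) for x in target_loader]
        mean_margins.append(torch.cat(margins).mean())
    return int(torch.stack(mean_margins).argmax())
```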
Instance-level UAD: As the target domain data are unannotated, we propose the instance-level UAD method for self-supervised learning on the target data with pseudo-labels. Specifically, for each instance \(x^i_T\) in the target domain, \(i=1,\dots,n_T\), we estimate the confidence measure (margin) of every model in the source model zoo and select the most confident source model to generate the pseudo-label:
\[\begin{align} \varepsilon_i = \mathop{\arg\max}_{j} \big([\mathcal{M}_{i,j}]_{j=1}^{N}\big),\; \label{single_margin} \end{align}\tag{3}\]
where \(\mathcal{M}_{i,j}\) denotes the margin of the \(j\)-th source model when predicting the target domain instance \(x^i_T\):
\[\begin{align} \mathcal{M}_{i,j} = \mathop{\mathrm{Topk}}_{k=1}\big(\delta(\theta_S^j(x^i_T))\big) - \mathop{\mathrm{Topk}}_{k=2}\big(\delta(\theta_S^j(x^i_T))\big), \quad i=1,\dots,n_T,\; j=1,\dots,N. \end{align}\]
For the instance \(x^i_T\), the corresponding pseudo-label is obtained from the prediction of the selected source model, which we define as \(\theta_T^i := \theta_S^{\varepsilon_i}\): \(\hat{y}^i_T = \mathop{\arg\max}_k \delta_k(\theta_T^i(x^i_T))\). The pseudo-labelled set \(\{x^i_T, \hat{y}^i_T\}_{i=1}^{n_T}\) is leveraged to fine-tune the target model, initialised as \(\theta_T = \theta_S^\ast\), by minimising the standard cross-entropy loss:
\[\mathcal{L}_{tar} = -\mathbb{E}_{(x_{T},\hat{y}_{T})\in \mathcal{X}_{T} \times \mathcal{\hat{Y}}_{T}} \sum\nolimits_{k=1}^{K} \mathbb{1}_{[k=\hat{y}_{T}]} \log \delta_k(\theta_T(x_{T})), \label{overall_loss}\tag{4}\]
where \(\mathbb{1}_{[\cdot]}\) equals \(1\) when the argument is true and \(0\) otherwise.
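Putting Eqs. (3) and (4) together, one adaptation step could look as follows. This is a sketch under the assumption that pseudo-labels are recomputed per batch; all names are our own.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def pseudo_label(source_models, x):
    """Instance-level UAD (Eq. 3): label each image with the prediction of
    the source model that is most confident (largest margin) on it."""
    probs = torch.stack([F.softmax(m(x), dim=1) for m in source_models])  # (N, B, K)
    top2 = probs.topk(k=2, dim=2).values
    margins = top2[..., 0] - top2[..., 1]                 # (N, B)
    best = margins.argmax(dim=0)                          # winning model per instance
    batch = torch.arange(x.size(0), device=x.device)
    return probs[best, batch].argmax(dim=1)               # hard pseudo-labels

def adaptation_step(target_model, optimizer, x, source_models):
    """One fine-tuning step of the target model on pseudo-labels (Eq. 4)."""
    y_hat = pseudo_label(source_models, x)
    loss = F.cross_entropy(target_model(x), y_hat)        # standard cross-entropy
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```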
For certain models, domain shift and limited data in the source domains may result in over- or under-confidence when predicting target domain data, which can trigger a mismatch between model prediction accuracy and confidence [19]. In other words, when this phenomenon occurs, the margin-based confidence measure is no longer an optimal guide for improving model prediction accuracy.
To address this problem, we embed Temperature Scaling (TS), which acts on the prediction probabilities by calibrating the logits prior to confidence measurement. In our approach, TS effectively regularises the representation of uncertainty in the model predictions, and a more precise, unbiased representation of uncertainty is preferable for the knowledge distillation process. The parameter \(\mathcal{T}\) is the so-called temperature: a larger temperature yields softer probability estimates and thus alleviates over-confidence in the model. For each source model \(\theta_S^{j}\), \(j=1,\dots,N\), we learn \(\mathcal{T}_{j}\) by setting an initial value \(\mathcal{T}_\text{initial}\) and applying temperature scaling on the target domain data \(\mathcal{D}_T\): \(\mathcal{T}_{j} = \text{TS-Alg}(\theta_S^{j}, \mathcal{D}_T)\). Specifically, the temperatures are tuned by minimising the expected calibration error (ECE), a.k.a. the calibration gap, defined as the difference between accuracy and confidence within each bin [20]:
\[\begin{align} \text{ECE} = \sum_{m=1}^{M}\frac{|B_m|}{n_T}\Big|\text{acc}(B_m) - \text{conf}(B_m)\Big|,\; \label{ece} \end{align}\tag{5}\]
where \(M\) denotes the number of interval bins into which predictions are grouped, and \(B_m\) is the set of indices of instances whose confidence falls in the interval \(I_m = (\frac{m-1}{M},\frac{m}{M}]\).
Given the logit vector \(\theta_S^j(x^i_T)\) obtained from each source model, the calibrated logits are computed as \(z_j = \theta_S^j(x^i_T) / \mathcal{T}_{j}\), where \(z_j\) is the calibrated pre-softmax output utilised in the confidence measures of Sec. 2.2.
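A sketch of the calibration step, under stated assumptions: the ECE follows Eq. (5), and since the optimiser inside TS-Alg is not specified in the text, a simple grid search over temperatures is used here. On unlabelled target data, `labels` would in practice have to be stand-ins such as pseudo-labels; that substitution is our assumption.

```python
import torch
import torch.nn.functional as F

def expected_calibration_error(probs, labels, n_bins=15):
    """ECE of Eq. (5): weighted average of |accuracy - confidence| per bin."""
    conf, preds = probs.max(dim=1)
    ece = 0.0
    for m in range(n_bins):
        in_bin = (conf > m / n_bins) & (conf <= (m + 1) / n_bins)
        if in_bin.any():
            acc = (preds[in_bin] == labels[in_bin]).float().mean().item()
            ece += in_bin.float().mean().item() * abs(acc - conf[in_bin].mean().item())
    return ece

def fit_temperature(logits, labels, grid=None):
    """One reading of TS-Alg: pick the temperature minimising the ECE."""
    grid = grid if grid is not None else torch.linspace(0.5, 5.0, 46)
    eces = [expected_calibration_error(F.softmax(logits / t, dim=1), labels)
            for t in grid]
    return grid[int(torch.tensor(eces).argmin())].item()
```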
Datasets: We evaluate the proposed multi-source-free domain adaptation framework on classification tasks over two series of multi-centre datasets [21]–[24]: a diabetic retinopathy (DR) grading benchmark comprising three domains (A, D, I) and the HAM10000 skin lesion benchmark comprising four domains (B, F, L, U).
In our experiments, we preprocess the data by first resizing images to \(256\times256\) and cropping them to \(224\times224\); we then assign each domain as the target in turn, treating the remaining domains as sources.
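The stated preprocessing, as a torchvision sketch; the crop type is not specified in the text, so `CenterCrop` is an assumption.

```python
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),  # resize to 256x256
    transforms.CenterCrop(224),     # crop type assumed; text only states "crop to 224"
    transforms.ToTensor(),
])
```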
Table 1: Classification accuracy (%) for the compared methods and ablations on the DR and HAM10000 benchmarks. Column headers denote source \(\rightarrow\) target domain configurations.

| Method | D, I \(\rightarrow\) A | A, I \(\rightarrow\) D | A, D \(\rightarrow\) I | DR AVG. | F, L, U \(\rightarrow\) B | B, L, U \(\rightarrow\) F | B, F, U \(\rightarrow\) L | B, F, L \(\rightarrow\) U | HAM10000 AVG. |
|---|---|---|---|---|---|---|---|---|---|
| AaD ('22) [13] | 36.13 | 33.07 | 46.32 | 38.51 | 64.55 | 64.30 | 65.14 | 72.36 | 66.59 |
| DECISION ('21) [11] | 57.32 | 45.43 | 58.33 | 53.69 | 74.27 | 76.24 | 71.06 | 78.98 | 75.14 |
| CAiDA ('21) [12] | 71.74 | 44.98 | 50.97 | 55.90 | 73.68 | 73.83 | 79.59 | 78.80 | 76.48 |
| M-UAD | 71.49 | 62.03 | 50.39 | 61.30 | 81.84 | 68.19 | 87.48 | 83.27 | 80.20 |
| I-UAD | 72.91 | 63.71 | 53.10 | 63.24 | 84.58 | 69.66 | 88.78 | 83.09 | 81.53 |
| M-UAD + I-UAD | 74.47 | 64.39 | 53.88 | 64.25 | 85.40 | 71.41 | 89.41 | 84.08 | 82.58 |
| M-UAD + I-UAD + TS | 74.52 | 65.27 | 58.72 | 66.17 | 85.40 | 73.29 | 89.70 | 84.44 | 83.21 |
Implementation Details: Following the top-ranked solution for medical image classification [25], we employ DenseNet-121 as the backbone. In the source model training process, we use smoothed labels instead of the usual one-hot labels to reduce overfitting and label noise. The maximum number of epochs \(\mathcal{N}_\text{epoch}\) for both the DR and HAM10000 datasets is set to \(100\); during the UAD process, \(\mathcal{N}_\text{epoch}\) is set to \(15\), with pseudo-labels regenerated at the start of each epoch. The batch size is set to 32, so each epoch comprises \(\mathcal{N}_\text{training data}/32\) iterations per domain. We use \(\mathcal{T}_\text{initial}=\log{(1/1.5)}\) for the DR dataset and \(1.5\) for the HAM10000 dataset. For both source model pre-training and adaptive distillation, we use stochastic gradient descent with momentum \(0.9\) and weight decay \(10^{-3}\), together with the learning rate scheduling method of [3].
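The optimisation setup above, as a hedged sketch; the base learning rate, the number of classes, and the label smoothing factor are not stated in the text and are placeholders here.

```python
import torch
from torchvision.models import densenet121

model = densenet121(num_classes=5)  # K = 5 is illustrative, not from the paper
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)  # smoothing value assumed
optimizer = torch.optim.SGD(model.parameters(),
                            lr=1e-2,  # base learning rate is a placeholder
                            momentum=0.9,
                            weight_decay=1e-3)
```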
For experimental comparison, we include one existing SFDA framework, AaD [13], with a multi-source extension, and two MSFDA frameworks, DECISION [11] and CAiDA [12], as baseline methods, all re-implemented with their default settings. The experimental results are reported in Table 1. The multi-source extension of AaD is implemented as an ensemble that passes the target data through each adapted source model and averages the soft predictions to obtain the test label. By examining the results over iterations of the SFDA process for DECISION, we note that, except for target domain I in DR and F in HAM10000, the performance of the DECISION model deteriorates as the number of target-model training iterations increases. This phenomenon is also observed in the CAiDA framework, although the degradation during domain adaptation is less severe than in DECISION. Intuitively, in a domain-biased, unsupervised setting, the model overfits to noisy labels when training on the target data, owing to the involvement of inappropriate source models and the low quality of the generated pseudo-labels.
In comparison with existing frameworks, our proposed method effectively mitigates both factors that can diminish the performance of the target domain model: we identify the most confident source model, excluding inappropriate ones from participating in target model training, and generate the most reliable pseudo-labels through the optimal source model. The last row of Table 1 shows that the average accuracy of domain adaptation via UAD (our method) on both datasets significantly outperforms all baselines.
Furthermore, we performed an ablation study on the domain adaptation process: model-level UAD alone and instance-level UAD alone, each evaluated without the subsequent fine-tuning stage, and the combination of model-level and instance-level UAD with fine-tuning but without temperature scaling.
Effectiveness of Model-level and Instance-level UAD: Inappropriate source models, i.e., those learned from source data that deviate significantly from the target data distribution, can disrupt the final performance of the target domain model; we therefore first propose to exclude such disruptive source model(s) from the training process. Using the model-level UAD (M-UAD) method, we pick the most confident source model, the optimal choice among the existing models, to serve as the initialisation of the target model training process. This establishes a solid foundation in the early stages of training. The first row of the ablation study (M-UAD) in Table 1 shows that implementing M-UAD alone already yields an improvement of approximately \(5\%\) on average over the baseline results.
In an unsupervised learning setting, pseudo-label generation is a crucial step in driving the eventual high-performance model; conversely, low-quality pseudo-labels cause the target model to gradually fit to these noisy labels, reducing its final performance. To prevent this, we propose the instance-level UAD (I-UAD) method, which identifies the most confident prediction for each individual instance as its pseudo-label. The second row of the ablation study (I-UAD) in Table 1 shows that applying I-UAD yields higher target model accuracy than the M-UAD approach.
The third row of the ablation study (M-UAD + I-UAD) in Table 1 shows that performance can be further improved by jointly applying the two-level UAD.
Effectiveness of Temperature Scaling: As described in Sec. 2.3, TS is an effective method for calibrating models that are over- or under-confident when predicting target domain data. The last row of Table 1 reports the result of applying TS to our combined UAD framework, showing improved average accuracy over the uncalibrated variant. The effect is particularly pronounced on target domains with relatively low accuracy, such as domain I in DR and domain F in HAM10000.
In this study, we proposed a two-level uncertainty-aware adaptive distillation method, termed UAD, a novel deep learning framework for multi-source-free unsupervised domain adaptation on medical imaging data, with successful application to datasets spanning different diseases and anatomical regions. By both initialising the target domain training process with the optimal source model and generating reliable pseudo-labels from a post-calibrated source model zoo, our method significantly outperforms existing frameworks on medical imaging data. In conclusion, our proposed method can fill the gap in the MSFDA setting in the field of medical image processing and analysis.
This research study was conducted retrospectively using human subject data made available in open access by [21]–[24]. Ethical approval was not required as confirmed by the license attached with the open access data.