August 16, 2022
Denoising Diffusion Probabilistic Models (DDPM) [1] and Vision Transformer (ViT) [2] have demonstrated significant progress in generative and discriminative tasks, respectively, and thus far these models have largely been developed in their own domains. In this paper, we establish a direct connection between DDPM and ViT by integrating the ViT architecture into DDPM, and introduce a new generative model called Generative ViT (GenViT). The modeling flexibility of ViT enables us to further extend GenViT to hybrid discriminative-generative modeling, and introduce a Hybrid ViT (HybViT). Our work is among the first to explore a single ViT for image generation and classification jointly. We conduct a series of experiments to analyze the performance of the proposed models and demonstrate their superiority over prior state-of-the-art methods in both generative and discriminative tasks. Our code and pre-trained models can be found at https://github.com/sndnyang/Diffusion_ViT.
Discriminative models and generative models based on Convolutional Neural Network (CNN) [3] architectures, such as GAN [4] and ResNet [5], have achieved state-of-the-art performance in a wide range of learning tasks. Thus far, they have largely been developed in two separate domains. In recent years, ViTs have started to rival CNNs in many vision tasks. Unlike CNNs, ViTs can capture features from an entire image through self-attention, and they have demonstrated superiority in modeling non-local contextual dependencies, as well as the efficiency and scalability to achieve comparable classification accuracy with smaller computational budgets (measured in FLOPs). Since their inception, ViTs have been exploited in various tasks such as object detection [6], video recognition [7], multi-modal pre-training [8], and image generation [9], [10]. In particular, VQ-GAN [11], TransGAN [9] and ViTGAN [10] investigate the application of ViT to image generation. However, VQ-GAN is built upon an extra CNN-based VQ-VAE, and the latter two require two ViTs to construct a GAN for generation tasks. Therefore, we ask the following question: is it possible to train a generative model using a single ViT?
DDPM is a class of generative models that matches a data distribution by learning to reverse a multi-step diffusion process. It has recently been shown that DDPMs can even outperform prior SOTA GAN-based generative models [12]–[14]. Unlike a GAN, which must train two competing networks, DDPM uses a UNet [15] as the backbone for image generation and is trained with a maximum-likelihood objective, which avoids the notorious training instability of GANs [13], [16] and EBMs [17], [18].
In this paper, we establish a direct connection between DDPM and ViT for the task of image generation and classification. Specifically, we answer the question of whether a single ViT can be trained as a generative model. We design Generative ViT (GenViT) for pure generation tasks, as well as Hybrid ViT (HybViT), which extends GenViT to a hybrid model for both image classification and generation. As shown in Figs. 2 and 3, the reconstruction of image patches and the classification are two routines that are independent of each other and jointly train a shared set of features.
Our experiments show that HybViT outperforms previous state-of-the-art hybrid models. In particular, the Joint Energy-based Model (JEM), the previous state-of-the-art proposed by [18], [19], requires extremely expensive MCMC sampling, which introduces instability and causes the training process to fail for large-scale datasets due to the long training procedures required. To the best of our knowledge, GenViT is the first model that utilizes a single ViT as a generative model, and HybViT is a new type of hybrid model that avoids expensive MCMC sampling during training. Compared to existing methods, our new models demonstrate a number of conceptual advantages [17]: 1) Our methods provide simplicity and stability similar to DDPM, and are less prone to collapse than GANs and EBMs. 2) The generative and discriminative paths of our model are trained with a single objective, which enables sharing of statistical strengths. 3) They inherit advantageous computational efficiency and scalability to growing model and data sizes from the ViT backbone.
Our contributions can be summarized as follows:
We propose GenViT, which, to the best of our knowledge, is the first approach to utilize a single ViT as an alternative to the UNet in DDPM.
We introduce HybViT, a new hybrid approach for image classification and generation leveraging ViT, and show that HybViT considerably outperforms the previous state-of-the-art hybrid models on both classification and generation tasks while optimizing more effectively than MCMC-based models such as JEM/JEM++.
We perform a comprehensive analysis of model characteristics, including adversarial robustness, uncertainty calibration, likelihood and OOD detection, comparing GenViT and HybViT with existing benchmarks.
We first review the derivation of DDPM [1]. DDPM is built upon the theory of Nonequilibrium Thermodynamics [20] with a few simple yet effective assumptions. It assumes diffusion is a noising process \(q\) that accumulates isotropic Gaussian noise over timesteps (Figure 1).
Starting from the data distribution \(\vec{x}_0 \sim q(\vec{x}_0)\), the diffusion process \(q\) produces a sequence of latents \(\vec{x}_1\) through \(\vec{x}_T\) by adding Gaussian noise at each timestep \(t \in \{1, \ldots, T\}\) with variance \(\beta_t \in (0,1)\) as follows: \[\begin{align} q(\vec{x}_1, ..., \vec{x}_T | \vec{x}_0) &\mathrel{\vcenter{:}}= \prod_{t=1}^{T} q(\vec{x}_t | \vec{x}_{t-1}) \tag{1} \\ q(\vec{x}_t | \vec{x}_{t-1}) &\mathrel{\vcenter{:}}= \mathcal{N}(\vec{x}_t; \sqrt{1-\beta_t} \vec{x}_{t-1}, \beta_t \mathbf{I}) \tag{2} \end{align}\]
Then, the reverse process aims to obtain a sample from \(q(\vec{x}_0)\), starting from \(\vec{x}_T \sim \mathcal{N}(0, \mathbf{I})\), by using a neural network: \[\label{eq:nn} p_{\theta}(\vec{x}_{t-1}|\vec{x}_t) \mathrel{\vcenter{:}}= \mathcal{N}(\vec{x}_{t-1}; \mu_{\theta}(\vec{x}_t, t), \Sigma_{\theta}(\vec{x}_t, t))\tag{3}\]
With the approximation of \(q\) and \(p\), DDPM gets a variational lower bound (VLB) as follows: \[\begin{align} \label{eq:elbo} \log p_{\boldsymbol{\theta}}(\boldsymbol{x}_0) & \geq \log p_{\boldsymbol{\theta}}(\boldsymbol{x}_0) - D_{KL}\infdivx{ q(\boldsymbol{x}_{1:T} | \boldsymbol{x}_0)}{p_{\boldsymbol{\theta}} (\boldsymbol{x}_{1:T} | \boldsymbol{x}_0) } \nonumber \\ & = - \mathbb{E}_q \left[ \log \frac{q(\boldsymbol{x}_{1:T} | \boldsymbol{x}_0)}{p_{\boldsymbol{\theta}} (\boldsymbol{x}_{0:T})} \right] \tag{4} \end{align}\]
The VLB can then be written as the loss: \[\begin{align} L_{\text{vlb}} &= L_0 + L_1 + ... + L_{T-1} + L_T \tag{5} \\ L_{0} &= -\log p_{\theta}(\vec{x}_0 | \vec{x}_1) \tag{6} \\ L_{t-1} &= D_{KL}\infdivx{q(\vec{x}_{t-1}|\vec{x}_t,\vec{x}_0)}{p_{\theta}(\vec{x}_{t-1}|\vec{x}_t)} \tag{7} \\ L_{T} &= D_{KL}\infdivx{q(\vec{x}_T | \vec{x}_0)}{p(\vec{x}_T)} \tag{8} \end{align}\] where \(L_0\) is modeled by an independent discrete decoder derived from the Gaussian \(\mathcal{N}(\vec{x}_{0}; \mu_{\theta}(\vec{x}_1, 1), \sigma_{1}^{2} \vec{I})\), and \(L_T\) is constant and can be ignored.
As noted in [1], the forward process can sample \(\boldsymbol{x}_t\) at an arbitrary timestep \(t\) directly conditioned on the input \(\boldsymbol{x}_0\) in closed form. To express this property, we define \(\alpha_t \mathrel{\vcenter{:}}= 1 - \beta_t\) and \(\bar{\alpha}_t \mathrel{\vcenter{:}}= \prod_{s=1}^{t} \alpha_s\). Then we have \[\begin{align} q(\vec{x}_t|\vec{x}_0) &= \mathcal{N}(\vec{x}_t; \sqrt{\bar{\alpha}_t} \vec{x}_0, (1-\bar{\alpha}_t) \mathbf{I}) \tag{9} \\ \vec{x}_t &= \sqrt{\bar{\alpha}_t} \vec{x}_0 + \sqrt{1-\bar{\alpha}_t} \epsilon \tag{10} \end{align}\] where \(\epsilon \sim \mathcal{N}(0,\mathbf{I})\) by the reparameterization trick. Then, using Bayes' theorem, we can calculate the posterior \(q(\vec{x}_{t-1}|\vec{x}_t,\vec{x}_0)\) in terms of \(\tilde{\beta}_t\) and \(\tilde{\mu}_t(\vec{x}_t,\vec{x}_0)\) as follows: \[\begin{align} q(\vec{x}_{t-1}|\vec{x}_t,\vec{x}_0) &= \mathcal{N}(\vec{x}_{t-1}; \tilde{\mu}_t(\vec{x}_t, \vec{x}_0), \tilde{\beta}_t \mathbf{I}) \tag{11} \\ \tilde{\mu}_t(\vec{x}_t,\vec{x}_0) &\mathrel{\vcenter{:}}= \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\vec{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \vec{x}_t \tag{12} \\ \tilde{\beta}_t &\mathrel{\vcenter{:}}= \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t \tag{13} \end{align}\]
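The closed-form sampling of Eq. 10 and the posterior parameters of Eqs. 12 and 13 can be sketched in a few lines. This is an illustrative sketch rather than the released implementation: the helper names are hypothetical, and the linear variance schedule with endpoints 1e-4 and 0.02 is an assumption made only for the example.

```python
import math
import random


def linear_betas(T, beta_start=1e-4, beta_end=0.02):
    """Assumed linear variance schedule; entry i holds beta_{i+1}."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        out.append(running)
    return out


def q_sample(x0, t, abar, rng):
    """Draw x_t ~ q(x_t | x_0) via Eq. 10 (t is 1-indexed). Returns (x_t, eps)."""
    a = abar[t - 1]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return xt, eps


def posterior_params(x0, xt, t, betas, abar):
    """Mean (Eq. 12) and variance (Eq. 13) of q(x_{t-1} | x_t, x_0), for t >= 2."""
    beta_t = betas[t - 1]
    alpha_t = 1.0 - beta_t
    ab_t, ab_prev = abar[t - 1], abar[t - 2]
    c0 = math.sqrt(ab_prev) * beta_t / (1.0 - ab_t)
    ct = math.sqrt(alpha_t) * (1.0 - ab_prev) / (1.0 - ab_t)
    mean = [c0 * a + ct * b for a, b in zip(x0, xt)]
    var = (1.0 - ab_prev) / (1.0 - ab_t) * beta_t
    return mean, var
```

Because \(\bar{\alpha}_{t-1} > \bar{\alpha}_t\), the posterior variance returned here is always smaller than \(\beta_t\).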
As we can observe, the objective in Eq. 5 is a sum of independent terms \(L_{t-1}\). Using Eq. 10, we can sample from an arbitrary step of the forward diffusion process and estimate \(L_{t-1}\) efficiently. Hence, DDPM uniformly samples \(t\) for each sample in each mini-batch to approximate the expectation \(E_{\boldsymbol{x}_0,t,\epsilon}[L_{t-1}]\) and thereby estimate \(L_{\text{vlb}}\).
To parameterize \(\mu_{\theta}(\vec{x}_t, t)\) so that it matches the posterior mean in Eq. 12, we can predict \(\mu_{\theta}(\vec{x}_t, t)\) directly with a neural network. Alternatively, we can first use Eq. 10 to replace \(\boldsymbol{x}_0\) in Eq. 12 and predict the noise \(\epsilon\) instead: \[\mu_{\theta}(\vec{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \vec{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_{\theta}(\vec{x}_t, t) \right). \label{mu_to_noise}\tag{14}\] The authors of [1] find that predicting the noise \(\epsilon\) works best with a reweighted loss function: \[\label{eq:loss_simple} L_{\text{simple}} = E_{t,\vec{x}_0,\epsilon}\left[ || \epsilon - \epsilon_{\theta}(\vec{x}_t, t) ||^2 \right].\tag{15}\] This objective can be seen as a reweighted form of \(L_{\text{vlb}}\) (without the terms affecting \(\Sigma_{\theta}\)). For more details on training and inference, we refer the reader to [1]. A closely related branch is score matching [21], [22], which builds a connection bridging DDPMs and EBMs. Our work is mainly built upon DDPM, but it is straightforward to substitute DDPM with a score matching method.
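As a sanity check on the noise parameterization, the following sketch evaluates Eq. 14 and the per-sample term of Eq. 15 on plain Python lists. The function names are hypothetical; a perfect noise prediction is assumed in the check so that Eq. 14 can be compared with the posterior mean of Eq. 12.

```python
import math


def mu_from_eps(xt, eps_hat, beta_t, abar_t):
    """Eq. 14: reverse-process mean computed from a noise prediction eps_hat."""
    alpha_t = 1.0 - beta_t
    scale = beta_t / math.sqrt(1.0 - abar_t)
    return [(x - scale * e) / math.sqrt(alpha_t) for x, e in zip(xt, eps_hat)]


def l_simple(eps, eps_hat):
    """Eq. 15 for a single sample: squared L2 distance between true and predicted noise."""
    return sum((a - b) ** 2 for a, b in zip(eps, eps_hat))


# Check: with the true noise plugged in, Eq. 14 reproduces the posterior mean of Eq. 12.
beta_t, abar_prev = 0.02, 0.5          # illustrative values, not from the paper
alpha_t = 1.0 - beta_t
abar_t = abar_prev * alpha_t
x0, eps = [1.0, -2.0], [0.3, -0.7]
xt = [math.sqrt(abar_t) * a + math.sqrt(1.0 - abar_t) * e for a, e in zip(x0, eps)]
c0 = math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)
ct = math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)
posterior_mean = [c0 * a + ct * b for a, b in zip(x0, xt)]
mu = mu_from_eps(xt, eps, beta_t, abar_t)
```

The two lists `mu` and `posterior_mean` agree up to floating-point error, which is the substitution argument behind Eq. 14.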
Transformers [23] have made a huge impact across many deep learning fields [24] due to their predictive power and flexibility. They are based on the concept of self-attention, a function that allows interactions with strong gradients between all inputs, irrespective of their spatial relationships. The self-attention layer (Eq. 16) encodes inputs as key-value pairs, where values \(\vec{V}\) represent embedded inputs and keys \(\vec{K}\) act as an indexing method; subsequently, a set of queries \(\vec{Q}\) is used to select which values to observe. Hence, a single self-attention head is computed as: \[\label{eqn:attention} \text{Attn}(\vec{Q}, \vec{K}, \vec{V}) = \text{softmax}\bigg( \frac{\vec{Q}\vec{K}^T}{\sqrt{d_k}} \bigg) \vec{V},\tag{16}\] where \(d_k\) is the dimension of \(\vec{K}\).
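A single attention head as in Eq. 16 can be written out on nested lists to make the shapes explicit. This is a minimal sketch for illustration; practical implementations use batched tensor operations and multiple heads.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Eq. 16 with Q: n x d_k, K: m x d_k, V: m x d_v (nested lists); returns n x d_v."""
    d_k = len(K[0])
    out = []
    for q in Q:
        # Scaled dot product of this query with every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out
```

A zero query yields uniform weights, so the output is the mean of the value rows; a query strongly aligned with one key returns approximately that key's value.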
Vision transformers (ViT) [2] have emerged as a prominent architecture that outperforms CNNs in various vision domains. The transformer encoder is constructed by alternating layers of multi-headed self-attention (MSA) and MLP blocks (Eqs. 18 and 19); layernorm (LN) is applied before every block, and a residual connection follows every block [25], [26]. The MLP contains two layers with a GELU non-linearity. The 2D image \(\boldsymbol{x} \in {\mathbb{R}}^{H \times W \times C}\) is flattened into a sequence of image patches, denoted by \(\boldsymbol{x}_p \in {\mathbb{R}}^{N \times (P^2 \cdot C)}\), where \(N=\frac{H\times W}{P^2}\) is the effective sequence length and \(P \times P \times C\) is the dimension of each image patch.
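The flattening of an image into a patch sequence can be illustrated as follows. The sketch assumes a raster-scan order over non-overlapping patches and a hypothetical helper name `patchify`; it operates on nested lists of shape \(H \times W \times C\).

```python
def patchify(img, P):
    """Split an H x W x C image (nested lists) into H*W/P^2 flattened patches.

    Patches are visited in raster-scan order (left to right, top to bottom),
    and each patch is flattened to a vector of length P*P*C.
    """
    H, W = len(img), len(img[0])
    assert H % P == 0 and W % P == 0, "patch size must divide the image size"
    patches = []
    for top in range(0, H, P):
        for left in range(0, W, P):
            patch = []
            for i in range(top, top + P):
                for j in range(left, left + P):
                    patch.extend(img[i][j])  # append the C channel values
            patches.append(patch)
    return patches
```

For a \(32 \times 32 \times 3\) image with \(P = 4\), this gives 64 patches of length 48.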
Following BERT [27], we prepend a learnable classification embedding \(\boldsymbol{x}_\text{class}\) to the image patch sequence, and then add the 1D positional embeddings \({\mathbf{E}}_{pos}\) to form the patch embedding \(\boldsymbol{z}_0\). The overall pipeline of ViT is as follows: \[\begin{align} \boldsymbol{z}_0 = &[ \boldsymbol{x}_\text{class}; \, \boldsymbol{x}^1_p \boldsymbol{E}; \, \boldsymbol{x}^2_p \boldsymbol{E}; \cdots; \, \boldsymbol{x}^{N}_p \boldsymbol{E} ] + \boldsymbol{E}_{pos}, \tag{17} \\ & \boldsymbol{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\, \boldsymbol{E}_{pos} \in \mathbb{R}^{(N + 1) \times D} \nonumber \\ \boldsymbol{z^\prime}_\ell = &\text{MSA}(\text{LN}(\boldsymbol{z}_{\ell-1})) + \boldsymbol{z}_{\ell-1}, \;\; \ell=1\ldots L \tag{18} \\ \boldsymbol{z}_\ell = &\text{MLP}(\text{LN}(\boldsymbol{z^\prime}_{\ell})) + \boldsymbol{z^\prime}_{\ell}, \;\; \;\; \;\; \ell=1\ldots L \tag{19} \\ \boldsymbol{y} = &\text{LN}(\boldsymbol{z}_L^0) \tag{20} \end{align}\]
ViTs have made significant breakthroughs in various discriminative and generative tasks, including image classification, multi-modal learning, and high-quality image and text generation [2], [10], [28]. Inspired by the parallelism between patches/embeddings of ViT, we experiment with applying a standard ViT directly to generative modeling with the fewest possible modifications.
Hybrid models [29] commonly model the density function \(p(\boldsymbol{x})\) and perform discriminative classification jointly using shared features. Notable examples are [18], [30]–[34].
Hybrid models can also combine two or more classes of generative models to balance trade-offs such as slow sampling and poor scalability with dimension. For example, a VAE can be improved by applying a second generative model such as a Normalizing Flow [35]–[37] or an EBM [38] in latent space. Alternatively, a second model can be used to correct samples [39]. In our work, we focus on training a single ViT as a hybrid model without any auxiliary model.
Energy-based models (EBMs) are an appealing family of models to represent data as they permit unconstrained architectures. Implicit EBMs define an unnormalized distribution over data typically learned through contrastive divergence [17], [40].
Joint Energy-based Model (JEM) [18] reinterprets the standard softmax classifier as an EBM and trains a single network to achieve impressive hybrid discriminative-generative performance. Beyond that, JEM++ [19] proposes several training techniques to improve JEM’s accuracy, training stability, and speed, including proximal gradient clipping, YOPO-based SGLD sampling, and informative initialization. Unfortunately, training EBMs using SGLD sampling is still impractical for high-dimensional data.
We propose GenViT by substituting UNet, the backbone of DDPM, with a single ViT. In our model design, we follow the standard ViT [2] as closely as possible. An overview of the architecture of the proposed GenViT is depicted in Fig. 2.
Given the input \(\boldsymbol{x}_t\) from DDPM, we follow the raster scan to get a sequence of image patches \(\boldsymbol{x}_p\), which is fed into GenViT as: \[\begin{align} {\mathbf{h}}_0 &= [ \boldsymbol{x}_\text{class}; \, \boldsymbol{x}_p^1 {\mathbf{E}}; \, \boldsymbol{x}_p^2 {\mathbf{E}}; \cdots; \, \boldsymbol{x}_p^N {\mathbf{E}}] + {\mathbf{E}}_{pos}, \nonumber \\ & {\mathbf{E}}\in {\mathbb{R}}^{(P^2 \cdot C) \times D},\, {\mathbf{E}}_{pos} \in {\mathbb{R}}^{(N + 1) \times D} \nonumber \\ \mathbf{h^\prime}_\ell &= \text{MSA}(\text{LN}(M(\mathbf{h}_{\ell-1}, \mathbf{A} ))) + \mathbf{h}_{\ell-1}, \ell=1,\ldots,L \nonumber \\ \mathbf{h}_\ell &= \text{MLP}(\text{LN}(M(\mathbf{h^\prime}_{\ell}, \mathbf{A}))) + \mathbf{h^\prime}_{\ell}, \ell=1,\ldots,L \nonumber \\ {\mathbf{y}}&= {\mathbf{h}}_L = [{\mathbf{y}}^1, \cdots, {\mathbf{y}}^N] \nonumber \\ \boldsymbol{x}' &= [\boldsymbol{x}_p^1, \cdots, \boldsymbol{x}_p^N] = [f_r({\mathbf{y}}^1),\ldots,f_r({\mathbf{y}}^N)], \label{eq:generator_final_rep} \tag{21} \\ & \boldsymbol{x}_p^i \in {\mathbb{R}}^{P^2 \times C}, \boldsymbol{x}' \in {\mathbb{R}}^{H \times W \times C}. \nonumber \end{align}\] Unlike ViT, GenViT takes the embedding of \(t\) as input to control the hidden features \(\mathbf{h}_{\ell}\) at every layer, and finally reconstructs an image \(\boldsymbol{x}'\) from the \(L\)-th layer output \(\mathbf{h}_L \in {\mathbb{R}}^{(N + 1) \times D}\). Following the design of UNet in DDPM, we first compute the embedding of \(t\) using an MLP, \(\mathbf{A} = \text{MLP}_t(t)\). Then we compute \(M( \mathbf{h}_{\ell}, \mathbf{A} ) = \mathbf{h}_{\ell} * (\mu_{\ell}(\mathbf{A}) + 1) + \sigma_{\ell}(\mathbf{A})\) for each layer, where \({\mu_{\ell}}(\mathbf{A}) = \text{MLP}_{\ell}(\mathbf{A})\).
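The timestep modulation \(M(\mathbf{h}_{\ell}, \mathbf{A})\) amounts to a per-feature scale and shift. The sketch below assumes that the scale \(\mu_{\ell}(\mathbf{A})\) and shift \(\sigma_{\ell}(\mathbf{A})\) are vectors of length \(D\) that are broadcast over all tokens; this broadcasting and the function name `modulate` are assumptions for illustration, and the MLPs that produce the two vectors are not modeled here.

```python
def modulate(h, scale, shift):
    """Apply M(h, A) = h * (scale + 1) + shift to every token.

    h:     list of tokens, each a list of D features
    scale: list of D values standing in for mu_l(A)     (assumed shape)
    shift: list of D values standing in for sigma_l(A)  (assumed shape)
    """
    return [
        [x * (s + 1.0) + b for x, s, b in zip(token, scale, shift)]
        for token in h
    ]
```

With the "+1" offset, a zero scale and zero shift leave the hidden features unchanged, so the modulation starts out as the identity whenever the time MLPs output zeros.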
JEM reinterprets the standard softmax classifier as an EBM and trains a single network for hybrid discriminative-generative modeling. Specifically, JEM maximizes the logarithm of the joint density function \(p_{\boldsymbol{\theta}}(\boldsymbol{x},y)\): \[\label{eq:jem_loss} \log p_{\boldsymbol{\theta}}(\boldsymbol{x}, y) = \log p_{\boldsymbol{\theta}}(y|\boldsymbol{x}) + \log p_{\boldsymbol{\theta}}(\boldsymbol{x}),\tag{22}\] where the first term is the cross-entropy classification objective, and the second term can be optimized by the maximum likelihood learning of EBM using contrastive divergence and MCMC sampling. However, MCMC-based EBMs are notorious for the expensive \(K\)-step MCMC sampling, which requires \(K\) full forward and backward propagations at every iteration. Hence, removing the MCMC sampling in training is a promising direction [33].
We propose Hybrid ViT (HybViT), a simple framework to extend GenViT for hybrid modeling. We substitute the optimization of \(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) in Eq. 22 with the VLB of GenViT in Eq. 4. Hence, we can train \(p(y|x)\) using the standard cross-entropy loss and optimize \(p(x)\) using the \(L_{\text{simple}}\) loss in Eq. 15. The final loss of our HybViT is \[\begin{align} L = & L_{\text{CE}} + \alpha L_{\text{simple}} \tag{23} \\ = & E_{\boldsymbol{x}_0, y} \left[ H(\boldsymbol{x}_0, y) \right] + \alpha E_{t,\vec{x}_0,\epsilon}\left[ || \epsilon - \epsilon_{\theta}(\vec{x}_t, t) ||^2 \right] \tag{24} \end{align}\] where \(H(\boldsymbol{x}_0, y)\) denotes the cross-entropy loss. We empirically find that a large coefficient, \(\alpha=100\), improves the generation quality while retaining comparable classification accuracy. The training pipeline is shown in Fig. 3.
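The combined objective of Eq. 23 can be sketched for a single example as follows. The helper names are hypothetical, and averaging the squared noise error over dimensions (rather than summing) is an assumption of this sketch; the classifier logits and the noise prediction are taken as given inputs.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one example, computed via a stable log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[label]


def hybrid_loss(logits, label, eps, eps_hat, alpha=100.0):
    """L = L_CE + alpha * L_simple (Eq. 23) for one classified image and one noised image."""
    l_ce = cross_entropy(logits, label)
    # Assumption: mean squared error over dimensions instead of the raw squared norm.
    l_simple = sum((a - b) ** 2 for a, b in zip(eps, eps_hat)) / len(eps)
    return l_ce + alpha * l_simple
```

With \(\alpha = 100\), even a small residual noise error dominates the cross-entropy term, which is consistent with the observation that a large coefficient favors generation quality.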
This section evaluates the discriminative and generative performance on multiple benchmark datasets, including CIFAR10, CIFAR100, STL10, CelebA-HQ-128, Tiny-ImageNet, and ImageNet 32x32.
Our code is largely built on top of ViT [41] and DDPM. Note that we set the batch size to 128, and we update all ViT-based models for 1170 iterations per epoch, compared with 390 iterations per epoch for CNN-based methods. Most ViT experiments run for 500 epochs, except for STL10 (2500 epochs) and ImageNet 32x32 (100 epochs). Thanks to the memory efficiency of ViT, all our experiments can be performed with PyTorch on a single Nvidia GPU. For reproducibility, our source code is provided in the supplementary material.
We first compare the performance with the state-of-the-art hybrid models and stand-alone discriminative and generative models on CIFAR10. We use accuracy, Inception Score (IS) [43] and Fréchet Inception Distance (FID) [44] as evaluation metrics; IS and FID measure the quality of generated images. The results on CIFAR10 are shown in Table 1. HybViT outperforms other hybrid models, including JEM (\(K\!=\!20\)) and JEM++ (\(M\!=\!20\)), on accuracy (95.9%) and FID score (26.4), while the original ViT achieves accuracy comparable to WideResNet (WRN) 28-10. Moreover, GenViT and HybViT are superior in training stability: HybViT trains stably, whereas JEM (\(K\!=\!20\)) and JEM++ (\(M\!=\!5\)) easily diverge at early epochs. The comparison results on more benchmark datasets, including CIFAR100, STL10, CelebA-128, Tiny-ImageNet, and ImageNet 32x32, are shown in Table 2. Example images generated by GenViT and HybViT are provided in Figs. 4 and 5, respectively. More generated images can be found in the appendix.
Model | Acc % \(\uparrow\) | IS \(\uparrow\) | FID \(\downarrow\) |
---|---|---|---|
ViT | 96.5 | - | - |
GenViT | - | 8.17 | 20.2 |
HybViT | 95.9 | 7.68 | 26.4 |
Single Hybrid Model | |||
IGEBM | 49.1 | 8.30 | 37.9 |
JEM | 92.9 | 8.76 | 38.4 |
JEM++ (M=20) | 94.1 | 8.11 | 38.0 |
JEAT | 85.2 | 8.80 | 38.2 |
Generative Models | |||
SNGAN | - | 8.59 | 21.7 |
StyleGAN2-ADA | - | 9.74 | 2.92 |
DDPM | - | 9.46 | 3.17 |
DiffuEBM | - | 8.31 | 9.58 |
VAEBM | - | 8.43 | 12.2 |
FlowEBM | - | - | 78.1 |
Other Models | |||
WRN-28-10 | 96.2 | - | - |
VERA(w/ generator) | 93.2 | 8.11 | 30.5 |
It is worth mentioning that the overall synthesis quality is worse than that of UNet-based DDPM. In particular, our methods do not generate realistic images for complex and high-resolution data. ViT is known to model global relations between patches but lacks a local inductive bias. We hope that advances in ViT architectures and DDPM, such as Performer [47], Swin Transformer [48], CvT [49] and Analytic-DPM [50], may address these issues in future work.
Model | Acc % \(\uparrow\) | IS \(\uparrow\) | FID \(\downarrow\) |
---|---|---|---|
CIFAR100 | |||
ViT | 77.8 | - | - |
GenViT | - | 8.19 | 26.0 |
HybViT | 77.4 | 7.45 | 33.6 |
WRN-28-10 | 79.9 | - | - |
SNGAN | - | 9.30 | 15.6 |
BigGAN | - | 11.0 | 11.7 |
Tiny-ImageNet | |||
ViT | 57.6 | - | - |
GenViT | - | 7.81 | 66.7 |
HybViT | 56.7 | 6.79 | 74.8 |
PreactResNet18 | 55.5 | - | - |
ADC-GAN | - | - | 19.2 |
STL10 | |||
ViT | 84.2 | - | - |
GenViT | - | 7.92 | 110 |
HybViT | 80.8 | 7.87 | 109 |
WRN-16-8 | 76.6 | - | - |
SNGAN | - | 9.10 | 40.1 |
ImageNet 32x32 | |||
ViT | 57.5 | - | - |
GenViT | - | 7.37 | 41.3 |
HybViT | 53.5 | 6.66 | 46.4 |
WRN-28-10 | 59.1 | - | - |
IGEBM | - | 5.85 | 62.2 |
KL-EBM | - | 8.73 | 32.4 |
CelebA 128 | |||
GenViT | - | - | 22.07 |
KL-EBM | - | - | 28.78 |
SNGAN | - | - | 24.36 |
UNet GAN | - | - | 2.95 |
In this section, we conduct a thorough evaluation of the proposed methods beyond accuracy and generation quality. Note that it is not our intention to propose approaches that match or outperform the best models in all metrics.
Recent works show that the predictions of modern convolutional neural networks can be over-confident due to increased model capacity [54]. Incorrect but confident predictions can be catastrophic for safety-critical applications. Hence, we investigate the calibration of ViT and HybViT using the Expected Calibration Error (ECE) metric. Interestingly, Fig. 6 shows that the predictions of both HybViT and ViT appear well-calibrated when trained with strong augmentations; however, they are less confident and have a worse ECE compared to WRN. More comparison results can be found in the appendix.
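For reference, ECE can be computed by binning predictions by confidence and averaging the gap between accuracy and mean confidence in each bin, weighted by bin size. The sketch below assumes equal-width bins over [0, 1] and a default of 10 bins; both choices and the function name are assumptions, not settings reported here.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE with equal-width confidence bins.

    confidences: max predicted probability per example, each in [0, 1]
    correct:     booleans, True when the prediction matched the label
    """
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        idx = min(int(c * n_bins), n_bins - 1)  # confidence 1.0 falls in the last bin
        bins[idx].append((c, ok))
    ece = 0.0
    for b in bins:
        if not b:
            continue
        avg_conf = sum(c for c, _ in b) / len(b)
        acc = sum(1.0 for _, ok in b if ok) / len(b)
        ece += len(b) / n * abs(acc - avg_conf)
    return ece
```

A model that is 90% confident and right 90% of the time scores an ECE of zero, whereas an under-confident model (accuracy above confidence) is penalized just like an over-confident one.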
Determining whether inputs are out-of-distribution (OOD) is an essential building block for safely deploying machine learning models in the open world. The model should assign lower scores to OOD examples than to in-distribution examples, so that the scores can be used to distinguish OOD examples from in-distribution ones. To evaluate OOD detection performance, we use a threshold-free metric, the Area Under the Receiver-Operating Curve (AUROC) [55]. Using the input density \(p_{\boldsymbol{\theta}}(\boldsymbol{x})\) [56] as the score, ViTs perform better at distinguishing in-distribution samples from out-of-distribution samples, as shown in Table 3.
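\(s_{\boldsymbol{\theta}}(\boldsymbol{x})\) | Model | SVHN | Interp | C100 | CelebA |
---|---|---|---|---|---|
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | WRN* | .91 | - | .87 | .78 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | IGEBM | .63 | .70 | .50 | .70 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | JEM | .67 | .65 | .67 | .75 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | JEM++ | .85 | .57 | .68 | .80 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | VERA | .83 | .86 | .73 | .33 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | KL-EBM | .91 | .65 | .83 | - |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | ViT | .93 | .93 | .82 | .81 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | HybViT | .93 | .92 | .84 | .76 |
* The result is from [57].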
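The AUROC values in Table 3 can be read as the probability that a randomly chosen in-distribution input receives a higher score than a randomly chosen OOD input. The following sketch computes this pairwise statistic directly; the function name is hypothetical, and the quadratic-time loop is chosen for clarity rather than efficiency.

```python
def auroc(in_scores, ood_scores):
    """AUROC as the fraction of (in-distribution, OOD) pairs ranked correctly.

    A pair counts as 1 when the in-distribution score is strictly higher,
    and as 1/2 when the two scores are tied.
    """
    wins = 0.0
    for s_in in in_scores:
        for s_out in ood_scores:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))
```

A value of 0.5 therefore corresponds to a score that carries no information about whether an input is OOD, and 1.0 to perfect separation.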
Adversarial examples [58], [59] trick neural networks into giving incorrect predictions by applying minimal perturbations to the inputs; hence, adversarial robustness is a critical characteristic of a model and has received an influx of research interest. In this paper, we investigate the robustness of models trained on CIFAR10 using the white-box PGD attack [60] under an \(L_\infty\) or \(L_2\) constraint. Fig. 7 compares ViT and HybViT with the baseline WRN-based classifier. We can see that ViT and HybViT have similar performance and both outperform WRN-based classifiers.
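To illustrate the mechanics of an \(L_\infty\)-bounded PGD attack, the sketch below attacks a toy logistic-regression classifier whose input gradient is available in closed form. The toy model, the step size, the number of steps, and the omission of a random start and of pixel-range clipping are all assumptions made for brevity; this is not the evaluation setup used for the networks above.

```python
import math


def bce_loss(x, y, w, b):
    """Binary cross-entropy of a logistic classifier p(y=1|x) = sigmoid(w.x + b)."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def pgd_linf(x, y, w, b, eps=0.1, step=0.02, n_steps=10):
    """Projected gradient ascent on the loss inside an L_inf ball of radius eps."""
    x_adv = list(x)
    for _ in range(n_steps):
        z = sum(wi * xi for wi, xi in zip(w, x_adv)) + b
        p = 1.0 / (1.0 + math.exp(-z))
        grad = [(p - y) * wi for wi in w]  # d(loss)/dx for logistic regression
        # Signed gradient step that increases the loss.
        x_adv = [xi + step * math.copysign(1.0, g) if g != 0.0 else xi
                 for xi, g in zip(x_adv, grad)]
        # Projection back onto the L_inf ball around the clean input.
        x_adv = [min(max(xa, x0 - eps), x0 + eps) for xa, x0 in zip(x_adv, x)]
    return x_adv
```

For a deep network, the closed-form gradient would be replaced by backpropagation through the classifier, while the signed step and the projection stay the same.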
An advantage of DDPM is that it can use the VLB as an approximate likelihood, whereas most EBMs cannot compute the intractable partition function w.r.t. \(\boldsymbol{x}\). Table 4 reports the test negative log-likelihood (NLL) in bits per dimension on CIFAR10. As we can observe, HybViT achieves a result comparable to GenViT, and both are worse than other methods.
In this section, we study the effect of different training configurations on the performance of image classification and generation by conducting an exhaustive ablation study on CIFAR10. In the main text, we investigate the impact of 1) training epochs, 2) the coefficient \(\alpha\), and 3) configurations of the ViT/HybViT architecture. Due to page limitations, more results can be found in the appendix.
Model | Acc % \(\uparrow\) | IS \(\uparrow\) | FID \(\downarrow\) |
---|---|---|---|
ViT (epoch=100) | 94.2 | - | - |
ViT (epoch=300) | 96.2 | - | - |
ViT (epoch=500) | 96.5 | - | - |
GenViT(epoch=100) | - | 7.25 | 33.3 |
GenViT(epoch=300) | - | 7.67 | 26.2 |
GenViT(epoch=500) | - | 8.17 | 20.2 |
HybViT(epoch=100) | 93.1 | 7.15 | 35.0 |
HybViT(epoch=300) | 95.9 | 7.59 | 29.5 |
HybViT(epoch=500) | 95.9 | 7.68 | 26.4 |
The results are reported in Tables 5 and 6. First, Table 5 shows a trade-off between the overall performance and computation time. The gain in classification and generation is relatively large when we prolong the training from 100 epochs to 300. With more training epochs, the accuracy gap between ViT and HybViT decreases. Furthermore, the generation quality can still improve slightly after 300 epochs. We then thoroughly explore the settings of the backbone ViT for GenViT and HybViT in Table 6. It can be observed that a larger \(\alpha\) is preferred, since it yields higher-quality generation with only a small drop in accuracy. The number of heads also has a small effect on the trade-off between classification accuracy and generation quality. Enlarging the model capacity, depth, or hidden dimensions can improve both the accuracy and the generation quality.
Model | Acc % \(\uparrow\) | IS \(\uparrow\) | FID \(\downarrow\) |
---|---|---|---|
HybViT | 95.9 | 7.68 | 26.4 |
HybViT(\(\alpha\)=1) | 96.6 | 4.74 | 68.9 |
HybViT(\(\alpha\)=10) | 97.0 | 6.40 | 38.2 |
HybViT(head=6) | 96.0 | 7.51 | 30.0 |
HybViT(head=8) | 95.9 | 7.74 | 28.0 |
HybViT(head=16) | 95.4 | 7.79 | 27.1 |
HybViT(depth=6) | 94.7 | 7.39 | 30.6 |
HybViT(depth=12) | 96.6 | 7.78 | 24.3 |
HybViT(dim=192) | 94.1 | 7.06 | 35.0 |
HybViT(dim=768) | 96.4 | 8.04 | 19.9 |
GenViT(dim=192) | - | 7.26 | 32.5 |
GenViT(dim=384) | - | 8.17 | 20.2 |
GenViT(dim=768) | - | 8.32 | 18.7 |
While it is challenging for our methods to generate realistic images for complex and high-resolution data, further improving the generation quality for such data is beyond the scope of this work and remains an exciting direction for future work. We hypothesize that the large patch size of the ViT architecture is the critical factor. Hence, we investigate the impact of different patch sizes on STL10 in Table 7. Although a smaller patch size improves the accuracy by a notable margin at the cost of a larger model, the generation quality for high-resolution images plateaus around a patch size of 6 (ps=6). These results indicate that the bottleneck of image generation comes from other components, such as the linear projections and reconstruction projections, rather than the multi-head self-attention. Note that a larger patch size (ps=12) further deteriorates the generation quality, which would lead to critical issues for high-resolution data like ImageNet, since the corresponding patch size is typically set to 14 or larger.
Model | NoP | Acc % \(\uparrow\) | IS \(\uparrow\) | FID \(\downarrow\) |
---|---|---|---|---|
ViT(ps=8) | 12.9M | 78.7 | - | - |
HybViT(ps=4) | 41.1M | 87.1 | 6.90 | 125.5 |
HybViT(ps=6) | 17.0M | 81.7 | 7.30 | 123.6 |
HybViT(ps=8) | 12.9M | 77.5 | 6.95 | 125.2 |
HybViT(ps=12) | 11.4M | 69.1 | 2.55 | 240.2 |
GenViT(dim=384) | 12.9M | - | 6.95 | 125.2 |
GenViT(dim=576) | 26.4M | - | 7.02 | 124.1 |
GenViT(dim=768) | 45.2M | - | 7.01 | 126.6 |
We report the empirical training speeds of our models and baseline methods on a single GPU for CIFAR10 in Table 8; those for ImageNet 32x32 are in the appendix. As discussed previously, two mini-batches are utilized in HybViT: one for training with \(L_{\text{simple}}\) and the other for training with the cross-entropy loss. Hence, HybViT requires about \(2\times\) the training time of GenViT. One of the advantages of GenViT and HybViT is that, even with many more (\(7.5\times\)) iterations, they still reduce training time significantly compared to EBMs. The results demonstrate that our new methods are much faster and more affordable for academic research settings.
Model | NoP(M) | Min/Epoch | Runtime(Hours) |
---|---|---|---|
ViT-based Models | |||
ViT(d=384) | 11.2 | 1.72 | 14.4 |
GenViT(d=384) | 11.2 | 2.11 | 17.6 |
HybViT(d=192) | 3.2 | 2.14 | 17.9 |
HybViT(d=384) | 11.2 | 3.71 | 31.2 |
HybViT(d=768) | 43.2 | 9.34 | 77.8 |
WRN-based Models | |||
WRN 28-10 | 36.5 | 1.53 | 5.2 |
JEM(K=20) | 36.5 | 30.2 | 101.3 |
JEM++(K=10) | 36.5 | 20.4 | 67.4 |
VERA | 40 | 19.3 | 64.3 |
IGEBM | - | 1 GPU for 2 days | |
KL-EBM | 6.64 | 1 GPU for 1 day | |
VAEBM* | 135 | 400 epochs, 8 GPUs, 55 hours | |
DDPM | 35.7 | 800k iter, 8 TPUs, 10.6 hours | |
DiffuEBM | 34.8 | 240k iter, 8 TPUs, 40+ hours |
* The runtime is for pretraining NVAE only. It further needs 25,000 iterations (or 16 epochs) on CIFAR-10 using one GPU for VAEBM.
As shown in previous sections, our models GenViT and HybViT exhibit promising results. However, compared to CNN-based methods, the main limitations are: 1) The generation quality is relatively low compared with pure generation (non-hybrid) SOTA models. 2) They require more training iterations to achieve high classification performance compared with pure classification models. 3) The sampling speed during inference is slow (typically \(T \geq 1000\) steps), while a GAN needs only a single forward pass.
We believe the results presented in this work are sufficient to motivate the community to address these limitations and improve speed and generative quality.
In this work, we integrate a single ViT into DDPM to propose a new type of generative model, GenViT. Furthermore, we present HybViT, a simple approach for training hybrid discriminative-generative models. We conduct a series of thorough experiments to demonstrate the effectiveness of these models on multiple benchmark datasets, with state-of-the-art results in most of the image classification and image generation tasks. We also investigate their intriguing properties, including likelihood, adversarial robustness, uncertainty calibration, and OOD detection. Most importantly, the proposed approach HybViT provides stable training, and outperforms the previous state-of-the-art hybrid models on both discriminative and generative tasks. While there are still challenges in training the models for high-resolution images, we hope the results presented here will encourage the community to improve upon current approaches.
The image benchmark datasets used in our experiments are described below:
CIFAR10 [63] contains 60,000 RGB images of size \(32\times 32\) from 10 classes, in which 50,000 images are for training and 10,000 images are for test.
CIFAR100 [63] also contains 60,000 RGB images of size \(32\times 32\), except that it contains 100 classes with 500 training images and 100 test images per class.
STL10 [64] 500 training images from 10 classes as CIFAR10, 800 test images per class.
Tiny-ImageNet contains 100000 images of 200 classes (500 for each class) downsized to 64×64 colored images. Each class has 500 training images, 50 validation images and 50 test images.
CelebA-HQ [65] is a human face image dataset. In our experiment, we use the downsampled version with size \(128\times 128\).
ImageNet 32x32 [66] is a downsampled variant of ImageNet with 1,000 classes. It contains the same number of images as vanilla ImageNet, but the image size is \(32\times 32\).
As discussed in the main text, all our experiments are based on the vanilla ViT of [41] and DDPM, and follow their settings. We use SGD for all datasets with an initial learning rate of 0.1, which we decay with a cosine scheduler. Table 9 lists the hyper-parameters used in our experiments. We also tried \(T = 4000\) and the \(L_2\) loss to train GenViT and HybViT, and the final results are comparable.
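To make the learning-rate schedule concrete, the following is a minimal sketch of a cosine schedule with warmup, using the initial learning rate (0.1) and the 10 warmup epochs from the hyper-parameter table. The linear shape of the warmup, the decay to zero, and the `total_epochs` argument are assumptions for illustration; they are not specified in this section.

```python
import math


def learning_rate(epoch, total_epochs, base_lr=0.1, warmup_epochs=10):
    """Learning rate for a 0-indexed epoch.

    Assumed shape: linear warmup up to base_lr, then cosine decay to zero.
    """
    if epoch < warmup_epochs:
        # Linear warmup: reaches base_lr at the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup schedule that has elapsed, in [0, 1].
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    # total_epochs=100 is a hypothetical value chosen only for this demo.
    for e in (0, 9, 10, 55, 99):
        print(e, round(learning_rate(e, total_epochs=100), 5))
```

In a PyTorch training loop the same effect is usually obtained with a built-in scheduler; the explicit function above only shows the shape of the curve.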
Variable | Value |
---|---|
Learning rate | 0.1 |
Batch Size | 128 |
Warmup Epochs | 10 |
Coefficient \(\alpha\) in HybViT | 1, 10, 100 |
Configurations of ViT | |
Dimensions | 384 |
Depth | 9 |
Heads | 12 |
Patch Size | 4, 8 |
Configurations of DDPM | |
Number of Timesteps \(T\) | 1000 |
Loss Type | \(L_1\) |
Noise Schedule | cosine |
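For reference, the cosine noise schedule with \(T = 1000\) listed above can be sketched as follows. We assume here that "cosine" refers to the schedule commonly used in improved DDPM implementations, \(\bar{\alpha}_t = f(t)/f(0)\) with \(f(t) = \cos^2\!\left(\frac{t/T + s}{1 + s}\cdot\frac{\pi}{2}\right)\). The offset `s = 0.008` and the clipping value `max_beta = 0.999` are the usual defaults of that formulation and are assumptions, not values reported in this paper.

```python
import math


def cosine_betas(T=1000, s=0.008, max_beta=0.999):
    """Per-step noise variances beta_1..beta_T for a cosine schedule.

    s and max_beta are assumed defaults, not values taken from the paper.
    """

    def alpha_bar(t):
        # Cumulative signal level after t steps (unnormalised by f(0)).
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for t in range(T):
        # beta_t = 1 - alpha_bar(t + 1) / alpha_bar(t), clipped for stability.
        beta = 1.0 - alpha_bar(t + 1) / alpha_bar(t)
        betas.append(min(beta, max_beta))
    return betas


if __name__ == "__main__":
    betas = cosine_betas()
    print(len(betas), betas[0], betas[-1])
```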
First, we investigate the gap between ViT, GenViT and HybViT in Fig. 8 on two benchmark datasets, CIFAR10 and ImageNet 32x32. The improvement in generation quality is relatively small after the first 10% of training epochs. The difference between samples with FID=40 and samples with FID=20 is almost visually imperceptible to humans, as shown in the corresponding figure. Hence, we think that accelerating the convergence of our models is an interesting direction for future work.
Following the setting of JEM [18], we conduct a qualitative analysis of samples on CIFAR10. We define an energy function of \(\boldsymbol{x}\) as \(p_{\boldsymbol{\theta}}(\boldsymbol{x}) \varpropto E(\boldsymbol{x}) = \log \sum_{y} e^{f_{\boldsymbol{\theta}}(\boldsymbol{x})\left[y\right]} = \text{LSE}( f_{\boldsymbol{\theta}}(\boldsymbol{x}))\), which is the negative of the energy function in [18], [57]. We use a CIFAR10-trained HybViT model to generate 10,000 images from scratch, then feed them back into the HybViT model to compute \(E(\boldsymbol{x})\) and \(p(y|\boldsymbol{x})\). We show the examples and the distribution by class in Fig. 10 and Fig. 11. The worst examples of the Plane class can be completely blank. Additional class-conditional samples (best and worst) generated by HybViT on CIFAR10 are provided in Figures 15-24.
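The two quantities computed for each generated image can be sketched as below, given the classifier logits \(f_{\boldsymbol{\theta}}(\boldsymbol{x})\). This is a minimal NumPy illustration, not the code used for the paper; the function names and the toy logits are hypothetical. Under the JEM formulation, \(\exp(E(\boldsymbol{x}))\) is the unnormalised density, so a larger \(E(\boldsymbol{x})\) corresponds to a sample the model considers more likely.

```python
import numpy as np


def lse_score(logits):
    """E(x) = log sum_y exp(f(x)[y]) for each row, computed stably."""
    m = logits.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True))).squeeze(1)


def class_posterior(logits):
    """p(y|x) = softmax(f(x)) for each row."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


if __name__ == "__main__":
    # Toy logits for two samples and three classes (hypothetical values).
    logits = np.array([[2.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0]])
    print("E(x):", lse_score(logits))
    print("p(y|x):", class_posterior(logits))
```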
Model | Aug | Acc % \(\uparrow\) | IS \(\uparrow\) | FID \(\downarrow\) |
---|---|---|---|---|
ViT | Strong | 96.5 | - | - |
ViT | Weak | 87.1 | - | - |
HybViT | Strong | 95.9 | 7.68 | 26.4 |
HybViT | Weak | 84.6 | 7.85 | 24.9 |
We study the effect of data augmentation. ViT is known to require a very large amount of training data and/or repeated strong data augmentation to learn acceptable visual representations. Table 10 compares the performance obtained with strongly augmented data and with conventional Inception-style pre-processed data [67] (namely, weak augmentation). We conclude that strong data augmentation is essential for high classification performance, while its effect on generation is negative but small. Note that data augmentation is used only for classification; for DDPM, we do not apply any data augmentation.
Another useful OOD score function is the maximum probability from a classifier’s predictive distribution: \(s_{\boldsymbol{\theta}}(\boldsymbol{x}) = \max_y p_{\boldsymbol{\theta}}(y|\boldsymbol{x})\). The results can be found in Table 11 (bottom rows).
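\(s_{\boldsymbol{\theta}}(\boldsymbol{x})\) | Model | SVHN | CIFAR10 Interp | CIFAR100 | CelebA |
---|---|---|---|---|---|
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | WideResNet [57] | .91 | - | .87 | .78 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | IGEBM [17] | .63 | .70 | .50 | .70 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | JEM [18] | .67 | .65 | .67 | .75 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | JEM++ [19] | .85 | .57 | .68 | .80 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | VERA [33] | .83 | .86 | .73 | .33 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | ImCD [51] | .91 | .65 | .83 | - |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | ViT | .93 | .93 | .82 | .81 |
\(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) | HybViT | .93 | .92 | .84 | .76 |
\(\max_y p_{\boldsymbol{\theta}}(y \mid \boldsymbol{x})\) | WideResNet | .93 | .77 | .85 | .62 |
\(\max_y p_{\boldsymbol{\theta}}(y \mid \boldsymbol{x})\) | IGEBM [17] | .43 | .69 | .54 | .69 |
\(\max_y p_{\boldsymbol{\theta}}(y \mid \boldsymbol{x})\) | JEM [18] | .89 | .75 | .87 | .79 |
\(\max_y p_{\boldsymbol{\theta}}(y \mid \boldsymbol{x})\) | JEM++ [19] | .94 | .77 | .88 | .90 |
\(\max_y p_{\boldsymbol{\theta}}(y \mid \boldsymbol{x})\) | ViT | .91 | .95 | .82 | .74 |
\(\max_y p_{\boldsymbol{\theta}}(y \mid \boldsymbol{x})\) | HybViT | .91 | .94 | .85 | .67 |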
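A score function \(s_{\boldsymbol{\theta}}(\boldsymbol{x})\) is typically turned into a single OOD detection number by measuring how well it separates in-distribution from out-of-distribution samples. The sketch below computes AUROC from two sets of scores. We assume, following the JEM evaluation protocol, that the entries of the table above are AUROC values; this section does not state the metric explicitly. The function name and the toy scores are hypothetical.

```python
import numpy as np


def auroc(scores_in, scores_out):
    """Probability that a random in-distribution sample receives a higher
    score than a random OOD sample; ties count as one half.

    Uses an explicit pairwise comparison, so memory grows as
    len(scores_in) * len(scores_out); fine for a sketch, not for huge sets.
    """
    scores_in = np.asarray(scores_in, dtype=float)
    scores_out = np.asarray(scores_out, dtype=float)
    diff = scores_in[:, None] - scores_out[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


if __name__ == "__main__":
    # Hypothetical scores, e.g. max_y p(y|x) on CIFAR10 test vs. an OOD set.
    in_dist = [0.99, 0.97, 0.80, 0.95]
    ood = [0.60, 0.85, 0.40, 0.70]
    print(auroc(in_dist, ood))
```

A value of 0.5 corresponds to a score that cannot distinguish the two sets, and 1.0 to perfect separation.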
Given ViT models trained with different data augmentations, we can investigate their robustness, since weak data augmentation is commonly used for CNNs. Table 12 shows an interesting phenomenon: HybViT with weak data augmentation is much more robust than the other models, especially under \(L_2\) attacks. We conjecture that this is because the noising process feeds a huge number of noisy samples to HybViT, which then learns implicitly from the noisy data and thereby improves its flatness and robustness.
Model | Clean (%) | \(L_\infty\), \(\epsilon=1/255\) | \(2/255\) | \(4/255\) | \(8/255\) | \(12/255\) | \(16/255\) | \(22/255\) | \(30/255\) |
---|---|---|---|---|---|---|---|---|---|
ViT | 96.5 | 70.8 | 46.7 | 21.7 | 7.0 | 1.4 | 0.1 | 0 | 0 |
- Weak Aug | 87.1 | 67.3 | 41.8 | 14.8 | 1.4 | 0.1 | 0 | 0 | 0 |
HybViT | 95.9 | 70.4 | 48.0 | 21.9 | 5.5 | 1.3 | 0.3 | 0 | 0 |
- Weak Aug | 84.6 | 71.3 | 55.6 | 30.3 | 6.7 | 0.6 | 0.1 | 0 | 0 |

Model | Clean (%) | \(L_2\), \(\epsilon=50/255\) | \(100/255\) | \(150/255\) | \(200/255\) | \(250/255\) | \(300/255\) | \(350/255\) | \(400/255\) |
---|---|---|---|---|---|---|---|---|---|
ViT | 96.5 | 52.3 | 9.2 | 1.1 | 0.3 | 0.1 | 0.1 | 0 | 0 |
- Weak Aug | 87.1 | 53.9 | 21.4 | 5.5 | 1.0 | 0.1 | 0 | 0 | 0 |
HybViT | 95.9 | 58.7 | 16.3 | 3.4 | 1.0 | 0.2 | 0.1 | 0.1 | 0 |
- Weak Aug | 84.6 | 65.8 | 42.3 | 25.7 | 13.2 | 6.4 | 3.4 | 1.5 | 0.7 |
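The perturbation budgets \(\epsilon\) in Table 12 bound the adversarial perturbation in the \(L_\infty\) or \(L_2\) norm, in units of 1/255 of the pixel range. As a minimal sketch of what these constraints mean, the functions below project a batch of perturbations back onto each norm ball, which is the step an iterative attack such as PGD repeats after every gradient update. The attack algorithm itself is not specified in this section, so the use of PGD-style projection is an assumption, and the function names are hypothetical.

```python
import numpy as np


def project_linf(delta, eps):
    """Clip every pixel of the perturbation to [-eps, eps]."""
    return np.clip(delta, -eps, eps)


def project_l2(delta, eps):
    """Rescale each sample's perturbation so its L2 norm is at most eps."""
    flat = delta.reshape(delta.shape[0], -1)
    norms = np.linalg.norm(flat, axis=1)
    # Shrink only the perturbations that exceed the budget.
    scale = np.minimum(1.0, eps / np.maximum(norms, 1e-12))
    return delta * scale.reshape((-1,) + (1,) * (delta.ndim - 1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    delta = rng.normal(size=(2, 3, 32, 32))  # batch of CIFAR-sized perturbations
    d_inf = project_linf(delta, 8 / 255)
    d_l2 = project_l2(delta, 50 / 255)
    print(np.abs(d_inf).max(), np.linalg.norm(d_l2.reshape(2, -1), axis=1))
```

Because the \(L_2\) norm aggregates over all pixels, an \(L_2\) budget of a given size permits much smaller per-pixel changes than an \(L_\infty\) budget of the same size, which is why the two halves of the table use different ranges of \(\epsilon\).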
Figure 12 compares ViT and HybViT with the baselines WRN and JEM, and with the corresponding ViTs trained without strong data augmentation. Strong data augmentation better calibrates the predictions of ViT and HybViT, but it also makes them under-confident.
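To make the notions of calibration and under-confidence concrete, the sketch below computes the expected calibration error (ECE) together with a signed accuracy-minus-confidence gap. The choice of ECE, the equal-width binning, and `n_bins=10` are assumptions for illustration; this section does not state which calibration metric or binning was used. A positive signed gap means that accuracy exceeds confidence, i.e. the model is under-confident.

```python
import numpy as np


def calibration_summary(confidences, correct, n_bins=10):
    """Return (ECE, signed gap) for predicted confidences in (0, 1].

    confidences: max_y p(y|x) for each sample.
    correct:     1 if the prediction was right, else 0.
    """
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            acc = correct[mask].mean()
            conf = confidences[mask].mean()
            # Weight each bin's gap by the fraction of samples it holds.
            ece += mask.mean() * abs(acc - conf)
    signed_gap = correct.mean() - confidences.mean()
    return float(ece), float(signed_gap)


if __name__ == "__main__":
    # Hypothetical predictions: always right, but only 65% confident.
    print(calibration_summary([0.65, 0.65, 0.65, 0.65], [1, 1, 1, 1]))
```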
We further report the empirical training speeds of our models and the baseline methods on ImageNet 32x32. Our methods are memory efficient, since they require only a single GPU, and they are much faster than the EBM baselines in terms of total GPU time.
Model | Number of parameters (M) | Runtime |
---|---|---|
ViT | 11.6 | 3 days |
GenViT | 11.6 | 2 days |
HybViT | 11.6 | 5 days |
IGEBM | - | 32 GPUs for 5 days |
KL-EBM | - | 8 GPUs for 3 days |
Additional generated samples of CIFAR10, CIFAR100, ImageNet 32x32, TinyImageNet, STL10, and CelebA 128 are provided in Figure 13. Generated images for ImageNet 128x128 and vanilla ImageNet 224x224 are shown in Figure 14. The patch sizes are set to 8 and 14 for ImageNet 128 and 224, respectively. Consistent with the previous discussion of patch size, we find the generation quality to be very low. Due to limited computational resources and the low generation quality, we show only preliminary generative results on ImageNet 128x128 and vanilla ImageNet 224x224.
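The effect of the patch size on the ViT input can be illustrated with a small calculation of how many tokens an image is split into and how many raw pixel values each token must represent. This is a simple sketch under the assumption of square, non-overlapping patches on RGB images; the function name is hypothetical.

```python
def patch_grid(image_size, patch_size, channels=3):
    """Return (number of tokens, pixel values per patch) for a square image
    split into non-overlapping square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side, patch_size * patch_size * channels


if __name__ == "__main__":
    # (image size, patch size) pairs mentioned in the text and in Table 9.
    for image_size, patch_size in [(32, 4), (32, 8), (128, 8), (224, 14)]:
        tokens, values = patch_grid(image_size, patch_size)
        print(f"{image_size}x{image_size}, patch {patch_size}: "
              f"{tokens} tokens, {values} values per patch")
```

With these settings, both high-resolution configurations yield the same sequence length as each other, while every token has to encode a larger block of pixels than in the \(32\times 32\) experiments with patch size 4.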