How Lightweight Can a Vision Transformer Be

Jen Hong Tan
Data Science and Artificial Intelligence Lab (DSAIL)
Health Services Research Unit
Singapore General Hospital


Abstract

In this paper, we explore a strategy that uses Mixture-of-Experts (MoE) to streamline, rather than augment, vision transformers. Each expert in an MoE layer is a SwiGLU feedforward network, where \(\mathbf{V}\) and \(\mathbf{W_2}\) are shared across the layer. No complex attention or convolutional mechanisms are employed. Depth-wise scaling is applied to progressively reduce the hidden size, and the number of experts is increased in stages. Grouped query attention is used. We studied the proposed approach with and without pre-training on small datasets and investigated whether transfer learning works at this scale. We found that the architecture is competitive even at a size of 0.67M parameters.

1 Introduction

In real-world applications of computer vision, such as edge intelligence, small and performant models are still preferred to overcome computational challenges [1]. Vision Transformers (ViTs) [2] have achieved remarkable results, but their performance drops significantly when the model size and dataset are small [3]. Consequently, there are studies investigating lightweight vision transformers that perform well on mid-size datasets. Almost all of these studies either use new types of attention blocks [4], [5] or integrate convolutional mechanisms [6] into their architectures.

On the other hand, Tan [7] has shown that by employing Masked Auto-Encoder (MAE) [8] as a pre-training strategy, it is possible to get ViT to learn effectively from small datasets. In that work, the ViT consists of 12 transformer encoder layers, each containing a multi-head attention component and a feedforward network. The feedforward network consists of two linear layers: the first expands the output to twice, rather than four times, the embedding size, and the second reduces the output back to the embedding size. The model can be lightened further by shrinking the hidden size in the middle of the feedforward network, but excessive reduction negatively affects model performance.

With these considerations in mind, we designed an architecture that uses Mixture-of-Experts (MoE) [9] to streamline vision transformers. In our architecture, each expert in an MoE layer is a SwiGLU [10] feedforward network. By design, SwiGLU is heavier in terms of parameters than a typical multi-layer perceptron. However, with several experts in an MoE layer, we are able to make the hidden size in SwiGLU smaller than the embedding size without negatively affecting model performance. Furthermore, we share 2 out of the 3 linear transformations of each SwiGLU across the layer. This significantly lowers the parameter count while maintaining the strength of MoE. Beyond that, to further reduce the number of parameters, we progressively increase the number of experts in the MoE in stages, while linearly reducing the hidden size with depth, borrowing the idea from depth-wise scaling [11]. Lastly, we use grouped query attention [12] to keep the parameter count low. Source code will be provided in the near future.

2 Method

Our proposed approach consists of two parts: mLiT and mmLiT. mLiT is a mixture-of-experts based Lightweight vision Transformer, and mmLiT is an mLiT pre-trained and fine-tuned using the masked auto-encoder pre-training strategy.

2.1 mLiT

Similar to the original ViT [2], we start by dividing each image into non-overlapping patches. These patches are linearly transformed into a set of embeddings and augmented with learnable positional embeddings. The processed embeddings then go through a series of MoE-based transformer encoder layers. Figure 1 shows the overall structure of our transformer encoder layer.

Figure 1: The structure of an MoE based transformer encoder layer. Each expert is a SwiGLU feedforward network. \(b\) stands for batch size, \(n\) for the number of embeddings and \(m\) for embedding size. In this layer there are \(t\) experts.

Figure 2: The working of MoE assuming the input is a vector. In this example, expert 1 and expert 3 receive the input as directed by the gating network

2.1.1 Grouped Query Attention (GQA)

Grouped query attention divides query heads into \(G\) groups, with each group sharing a single key head and value head. For instance, GQA-1, with a single group, is equivalent to multi-query attention (MQA); while GQA-\(H\), with groups equal to the number of heads (\(H\)), is equivalent to multi-head attention (MHA). An intermediate number of groups results in a model that is higher quality than MQA but faster than MHA [12], providing a favorable trade-off between performance and memory bandwidth.
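
As an illustration, the following is a minimal PyTorch sketch of grouped query attention under the mLiT-XXS settings (embedding size 108, 6 query heads, 3 key/value groups; see Table 1). The class and variable names are ours, and dropout and masking are omitted; this is a sketch of the mechanism, not the paper's implementation.

```python
# A hedged sketch of grouped query attention (GQA); illustrative, not the paper's code.
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, num_groups: int):
        super().__init__()
        assert num_heads % num_groups == 0
        self.num_heads, self.num_groups = num_heads, num_groups
        self.head_dim = embed_dim // num_heads
        # One projection per query head; one key/value projection per group.
        self.q_proj = nn.Linear(embed_dim, num_heads * self.head_dim)
        self.kv_proj = nn.Linear(embed_dim, 2 * num_groups * self.head_dim)
        self.out_proj = nn.Linear(num_heads * self.head_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        q = self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        kv = self.kv_proj(x).view(b, n, 2, self.num_groups, self.head_dim)
        k, v = kv[:, :, 0].transpose(1, 2), kv[:, :, 1].transpose(1, 2)
        # Each key/value head serves num_heads // num_groups query heads.
        repeat = self.num_heads // self.num_groups
        k, v = k.repeat_interleave(repeat, dim=1), v.repeat_interleave(repeat, dim=1)
        attn = F.softmax((q @ k.transpose(-2, -1)) / self.head_dim ** 0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.out_proj(out)


# mLiT-XXS settings: embedding size 108, 6 heads, 3 groups (GQA-3)
gqa = GroupedQueryAttention(108, num_heads=6, num_groups=3)
y = gqa(torch.randn(2, 145, 108))   # (batch, patches + dummy patch, embedding)
```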

2.1.2 Mixture of Experts (MoE)

A single Mixture-of-Experts layer comprises \(t\) expert networks \(E_1, E_2, \cdots, E_t\), together with a gating network \(G\) that directs each input vector/embedding to the relevant expert(s) in the layer. For each vector or embedding, the output of \(G\) is a sparse \(t\)-dimensional vector. Figure 2 shows an overview of the working of MoE when the input is a vector.

Let’s denote a vector/embedding by \(x\), the output of the \(i\)-th expert by \(E_i(x)\), and the \(i\)-th component of the output of the gating network \(G(x)\) by \(G(x)_i\). The output \(y\) of the MoE is given by:

\[y = \sum_{i=1}^t G(x)_i E_i(x) \]

Because \(G(x)\) is sparse, no computation is performed by the \(j\)-th expert when \(G(x)_j = 0\). In many applications, however, the input to the layer is of shape \((b,n,m)\), where \(b\), \(n\), and \(m\) denote the batch size, the number of embeddings and the embedding size, respectively. Therefore, it is common to reshape the input tensor into 2D [13] before the MoE layer. Figure 3 shows the workflow of the actual implementation of MoE.

Figure 3: The actual implementation of MoE. Unlike the case when the input is just a vector, when \(b \cdot n\) embeddings go into the layer, almost all the expert networks will receive some embeddings sent by the dispatcher
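
A minimal sketch of such a dispatcher is given below, assuming PyTorch, a reshaped input of shape \((b \cdot n, m)\), a sparse gate matrix of shape \((b \cdot n, t)\), and experts that map \(m\)-dimensional embeddings back to \(m\) dimensions. The per-expert loop is for illustration only and is not the batched dispatcher of [13].

```python
# A hedged sketch of MoE dispatch-and-combine over a sparse gate matrix.
import torch
import torch.nn as nn


def moe_forward(x, gates, experts):
    """x: (b*n, m); gates: (b*n, t), sparse; experts: list of t modules mapping m -> m."""
    y = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        idx = torch.nonzero(gates[:, i], as_tuple=True)[0]   # rows routed to expert i
        if idx.numel() == 0:
            continue
        # combine: weight each expert's output by its gate value
        y[idx] += gates[idx, i].unsqueeze(-1) * expert(x[idx])
    return y


# toy usage: 10 embeddings of size 108 routed among t = 3 linear "experts"
experts = [nn.Linear(108, 108) for _ in range(3)]
x = torch.randn(10, 108)
gates = torch.zeros(10, 3).scatter_(1, torch.randint(3, (10, 2)), 0.5)
y = moe_forward(x, gates, experts)
```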

Noisy Top-K Gating Network. We follow the proposal by Shazeer et al. [9] that adds sparsity and noise to a softmax gating mechanism. Assume a reshaped input tensor \(\mathbf{x}\) of shape \((b \cdot n, m)\) and two trainable weight matrices \(\mathbf{W}_g\) and \(\mathbf{W}_{noise}\), both of shape \((m, t)\), where \(t\) is the total number of experts in the layer and \(k\) is the number of experts to be selected. The output of the gating network is given by

\[G(\mathbf{x}) = \mathrm{Softmax}_k \left( H(\mathbf{x}) \right)\] where \[H(\mathbf{x}) = \mathbf{x}\mathbf{W}_g + \mathrm{randn\_like}(\mathbf{x}\mathbf{W}_g) \odot \mathrm{Softplus}\left( \mathbf{x}\mathbf{W}_{noise} \right) \]

\(\odot\) denotes the element-wise product. \(\mathrm{randn\_like}\) generates a tensor filled with random numbers from a normal distribution with mean 0 and variance 1, and the shape of the tensor is equal to the shape of the output from \(\mathbf{x} \mathbf{W}_g\). \(\mathrm{Softmax}_k\) applies softmax (row-wise) only on the top \(k\) elements, with the rest set to 0 in the output. The equation below illustrates how \(\mathrm{Softmax}_k\) works when \(k\) is 2 and \(t\) (the number of experts) is 4:

\[\mathrm{Softmax}_k \left( \left[ \begin{array}{rrrr} 4.4742 & -5.6365 & 6.8226 & 0.9960 \\ 3.5298 & 2.3049 & 1.2113 & -1.3946 \\ -2.2414 & 0.3925 & 1.6676 & -1.9253 \\ \end{array} \right]\right) = \begin{bmatrix} 0.0872 & 0.0000 & 0.9128 & 0.0000 \\ 0.7729 & 0.2271 & 0.0000 & 0.0000 \\ 0.0000 & 0.2184 & 0.7816 & 0.0000 \\ \end{bmatrix} \]
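
A short PyTorch sketch of this gating is shown below; the function and variable names are ours. The noisy logits \(H\) are returned as well because they are reused by the load balancing loss below, and the second half reproduces the worked \(\mathrm{Softmax}_k\) example above (with the noise term omitted).

```python
# A hedged sketch of noisy top-k gating; illustrative, not the paper's implementation.
import torch
import torch.nn.functional as F


def noisy_topk_gating(x, W_g, W_noise, k):
    """x: (b*n, m); W_g, W_noise: (m, t). Returns gates (b*n, t) and noisy logits H."""
    clean = x @ W_g
    noise_std = F.softplus(x @ W_noise)
    H = clean + torch.randn_like(clean) * noise_std
    topk_val, topk_idx = H.topk(k, dim=-1)
    # Softmax_k: softmax over the top-k logits only, zeros elsewhere
    gates = torch.zeros_like(H).scatter_(-1, topk_idx, F.softmax(topk_val, dim=-1))
    return gates, H


# Reproducing the worked Softmax_k example with k = 2 and t = 4 (noise omitted)
H = torch.tensor([[4.4742, -5.6365, 6.8226, 0.9960],
                  [3.5298, 2.3049, 1.2113, -1.3946],
                  [-2.2414, 0.3925, 1.6676, -1.9253]])
val, idx = H.topk(2, dim=-1)
gates = torch.zeros_like(H).scatter_(-1, idx, F.softmax(val, dim=-1))
```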

Losses for MoE. To encourage all experts to be of equal importance, two loss functions are introduced: a load balancing loss and an importance loss. To calculate the load balancing loss, we first determine the following probability:

\[P(\mathbf{x}) = \Phi \left( \frac{\mathbf{x}\mathbf{W}_g- \Psi(H(\mathbf{x}))}{\mathrm{Softplus}(\mathbf{x}\mathbf{W}_{noise})} \right) \] where \(\Phi\) is the cumulative distribution function of a standard normal distribution. To calculate the output of \(\Psi\), we let \(\mathbf{H}\) denote the output of \(H(\mathbf{x})\), which has a shape of \((b \cdot n, t)\), with each element in \(\mathbf{H}\) denoted by \(h_{r,c}\). Furthermore, we let \(\mathbf{H}_r\) denote the row with index \(r\) in \(\mathbf{H}\). Similarly, \(\psi_{r,c}\) denotes each element in the output of \(\Psi(H(\mathbf{x}))\). If we let \(l^{k}_r\) and \(l^{k+1}_r\) denote the \(k^{\mathrm{th}}\) and \((k+1)^{\mathrm{th}}\) largest values, respectively, for a row \(r\) in \(\mathbf{H}\), with \(l^{k}_r > l^{k+1}_r\), then

\[\psi_{r,c} = \left\{ \begin{array}{ll} l^{k+1}_r \text{ in } \mathbf{H}_r & \text{if } h_{r,c} \geq l^{k}_r \\ l^{k}_r \text{ in } \mathbf{H}_r & \text{if } h_{r,c} < l^{k}_r \\ \end{array} \right.\] where \(k\) is the number of experts to be selected in the layer. As an example, assume we have

\[\mathbf{H}= \left[ \begin{array}{rrrr} 4.4742 & -5.6365 & 6.8226 & 0.9960 \\ 3.5298 & 2.3049 & 1.2113 & -1.3946 \\ -2.2414 & 0.3925 & 1.6676 & -1.9253 \\ \end{array} \right] \] and \(k\) is 2, then for each row in \(\mathbf{H}\) (with row index starting from 0), we have

\[\begin{array}{crr} \hline \text{row} & k^{\mathrm{th}} \text{ largest} & (k+1)^{\mathrm{th}} \text{ largest} \\ \hline 0 & 4.4742 & 0.9960 \\ 1 & 2.3049 & 1.2113 \\ 2 & 0.3925 & -1.9253 \\ \hline \end{array} \] The output of \(\Psi\) is

\[\Psi(\mathbf{H})= \left[ \begin{array}{rrrr} 0.9960 & 4.4742 & 0.9960 & 4.4742 \\ 1.2113 & 1.2113 & 2.3049 & 2.3049 \\ 0.3925 & -1.9253 & -1.9253 & 0.3925 \\ \end{array} \right] \] Lastly, with the probability calculated, the load balancing loss is given by:

\[L_{load}(\mathbf{x})= w_{load} \times \mathrm{CV}\left( \mathrm{sum}_0 (P(\mathbf{x}))\right) \] where \(w_{load}\) is a hand-tuned scaling factor, \(\mathrm{CV}\) is the function that calculates the coefficient of variation of its input, and \(\mathrm{sum}_0\) performs column-wise summation on \(P(\mathbf{x})\). To determine the importance loss, we compute

\[L_{importance}(\mathbf{x}) = w_{importance} \times \mathrm{CV}\left( \mathrm{sum}_0 (G(\mathbf{x})) \right) \] and the total loss is

\[L = L_{importance}+L_{load}\label{eq:moe95loss}\tag{1}\] In our implementation, we set \(w_{importance} = w_{load} = 1 \times 10^{-2}\).
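
A compact PyTorch sketch of these two losses follows, using the gating function sketched earlier; the \(\mathrm{CV}\) implementation (standard deviation over mean) and the small numerical safeguard are our assumptions, not details from the paper.

```python
# A hedged sketch of the MoE auxiliary losses (load balancing + importance).
import torch
import torch.nn.functional as F


def cv(v, eps=1e-10):
    """Coefficient of variation of a 1-D tensor."""
    return v.std() / (v.mean() + eps)


def moe_aux_loss(x, H, gates, W_g, W_noise, k, w_load=1e-2, w_importance=1e-2):
    clean = x @ W_g
    noise_std = F.softplus(x @ W_noise)
    # Psi: the (k+1)-th largest noisy logit if the logit is among the top k, else the k-th
    top = H.topk(k + 1, dim=-1).values              # (b*n, k+1), sorted descending
    kth, kplus1 = top[:, k - 1:k], top[:, k:k + 1]
    psi = torch.where(H >= kth, kplus1, kth)
    # P(x) = Phi((xW_g - Psi) / Softplus(xW_noise)); column-wise sums, then CV
    P = torch.distributions.Normal(0.0, 1.0).cdf((clean - psi) / noise_std)
    load_loss = w_load * cv(P.sum(dim=0))
    importance_loss = w_importance * cv(gates.sum(dim=0))
    return load_loss + importance_loss
```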

SwiGLU. Each expert in an MoE layer is a SwiGLU FeedForward Network (FFN). Assume we have a tensor \(\mathbf{x}\) of shape \((n, m)\), where \(n\) is the number of embeddings and \(m\) is the embedding size. We define \(\mathbf{W}\), \(\mathbf{V}\) and \(\mathbf{W_2}\) of shape \((m,d_h)\), \((m, d_h)\) and \((d_h, m)\), respectively, where \(d_h\) is the hidden size. The output of the network is given by:

\[\mathrm{FFN_{SwiGLU}}(\mathbf{x}) = \left(\mathrm{silu}\left(\mathbf{xW} \right) \odot \left(\mathbf{xV} \right) \right)\mathbf{W}_2 \] Biases are omitted in the above for \(\mathbf{W}, \mathbf{V}\) and \(\mathbf{W_2}\). \(\mathrm{silu}\) is the swish function:

\[\mathrm{silu}(x) = x * \sigma(x) \] where \(\sigma(x)\) is the logistic sigmoid. Dropout is applied on the output of SwiGLU FFN.

In mLiT, \(\mathbf{V}\) and \(\mathbf{W_2}\), together with their corresponding biases, are shared by all experts within an MoE layer.
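
A hedged PyTorch sketch of such a layer of experts is shown below: one \(\mathbf{W}\) per expert, with a single \(\mathbf{V}\) and \(\mathbf{W_2}\) (and their biases) shared by all of them. Module and argument names are illustrative.

```python
# A hedged sketch of SwiGLU experts sharing V and W2 within one MoE layer.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedSwiGLUExperts(nn.Module):
    def __init__(self, embed_dim: int, hidden_dim: int, num_experts: int, dropout: float = 0.1):
        super().__init__()
        # One W per expert; a single V and W2 shared across the layer.
        self.W = nn.ModuleList([nn.Linear(embed_dim, hidden_dim) for _ in range(num_experts)])
        self.V = nn.Linear(embed_dim, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def expert(self, i: int, x: torch.Tensor) -> torch.Tensor:
        # FFN_SwiGLU(x) = (silu(x W_i) * (x V)) W2, with dropout on the output
        return self.dropout(self.W2(F.silu(self.W[i](x)) * self.V(x)))
```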

2.1.3 Depth-wise Scaling

Our vision transformer does not have a fixed hidden size \(d_h\) across the transformer encoder layers. Instead, the hidden size is linearly reduced from the first layer to the last layer. Let \(d^{\mathrm{first}}_h\) and \(d^{\mathrm{last}}_h\) denote the hidden sizes of the first and the last layer, respectively; the hidden size \(d^i_h\) of layer \(i\) is then given by

\[d^i_h = \left\lfloor \frac{(L_E - 1)-i}{L_E-1}\left(d^{\mathrm{first}}_h-d^{\mathrm{last}}_h \right) \right\rfloor + d^{\mathrm{last}}_h \label{eq:hid_size}\tag{2}\] where \(L_E\) is the total number of transformer encoder layers, \(\lfloor \cdot \rfloor\) is the floor operator, and \(i\) is the layer index starting from 0.
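
A short sketch of Equation 2 is given below; for the 9-layer configuration of Figure 4 (\(d^{\mathrm{first}}_h = 81\), \(d^{\mathrm{last}}_h = 27\)) it yields hidden sizes 81, 74, 67, 60, 54, 47, 40, 33 and 27.

```python
# Hidden size per layer according to Equation 2 (depth-wise scaling).
def hidden_size(i: int, num_layers: int, d_first: int, d_last: int) -> int:
    return ((num_layers - 1 - i) * (d_first - d_last)) // (num_layers - 1) + d_last


sizes = [hidden_size(i, num_layers=9, d_first=81, d_last=27) for i in range(9)]
# [81, 74, 67, 60, 54, 47, 40, 33, 27]
```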

Furthermore, the number of experts in mLiT is increased in stages, from 3 to 5: the count increases by one after every 3, 4, or 5 layers, depending on the model size. See Figure 4 for more detail.

Figure 4: The overall architecture of mLiT. The linear transformation and positional embeddings before the first layer are not included in the illustration. \(b\) stands for batch size, \(n\) for the number of embeddings and \(m\) for embedding size. The input and the output embedding size of each SwiGLU FFN expert are equal to the embedding size \(m\). The model in this figure has 9 MoE based transformer encoder layers, with the number of experts increased by 1 at layer 3 and layer 6. \(d^{\mathrm{first}}_h\) and \(d^{\mathrm{last}}_h\) are 81 and 27, respectively. The hidden size at each layer is calculated using Equation 2 .

2.2 mmLiT

In mmLiT, we apply a masked auto-encoder on mLiT. We follow strictly the process outlined in [7], which is based on the original MAE paper [8]. The masking ratio is set to 0.75. However, unlike in [7], we did not incorporate a separate set of learnable positional embeddings at the decoder’s input (see Figure 5). Similar to [7], we compute the loss of the auto-encoder as the sum of the loss on masked patches and an additional, discounted loss on unmasked patches. The total loss is given by

\[\mathrm{Loss} = \mathrm{MSE}_{\mathrm{masked}} + \alpha \cdot \mathrm{MSE}_{\mathrm{unmasked}} + \beta \cdot L \label{eq:loss}\tag{3}\] where \(\alpha\) denotes the discounting factor and \(\beta\) is the loss coefficient for MoE; \(L\) is calculated from Equation 1.
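
A brief sketch of Equation 3 follows, assuming PyTorch, per-patch predictions and targets of shape \((b, n, p \cdot p \cdot c)\), and a binary mask with 1 for masked patches; the averaging convention is our assumption, and \(\alpha = 0.1\), \(\beta = 0.5\) follow Table 3.

```python
# A hedged sketch of the mmLiT pre-training loss (Equation 3).
import torch
import torch.nn.functional as F


def mae_total_loss(pred, target, mask, moe_loss, alpha=0.1, beta=0.5):
    """pred, target: (b, n, p*p*c); mask: (b, n) with 1 for masked patches."""
    per_patch = F.mse_loss(pred, target, reduction="none").mean(dim=-1)   # (b, n)
    mse_masked = (per_patch * mask).sum() / mask.sum()
    mse_unmasked = (per_patch * (1 - mask)).sum() / (1 - mask).sum()
    return mse_masked + alpha * mse_unmasked + beta * moe_loss
```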

Figure 5: The architecture of mmLiT. \(b\), \(m\) and \(p\) stand for batch size, embedding size and patch size, respectively. \(n_E\) and \(n_D\) are the numbers of embeddings/patches at the encoder and decoder, respectively, with \(n_E + n_D = n\). \(L_E\) is the total number of MoE transformer encoder layers at the encoder; \(L_D\) is the total number of layers at the decoder.

3 Experimental Setup

We investigate the performance of mLiT and mmLiT at three different sizes: S, XS and XXS. Table 1 shows the details of the encoder of each model, and Table 2 those of the decoder. For mmLiT, we perform self-supervised pre-training on the Cifar100 [14] training set and fine-tune on Cifar100, Cifar10 [14], Flowers102 [15] and Svhn [16]. Additionally, we conduct supervised learning (training from scratch) on the aforementioned four datasets.

The input image to the models has a size of 36 x 36, slightly larger than the original 32 x 32 dimensions of these datasets (except Flowers102, whose images are much larger and more varied). This results in each image being divided into 144 patches, given the patch size of 3 x 3. Similar to [7], an auxiliary dummy patch is added to each image during both the pre-training and fine-tuning phases. This dummy patch, which contains zero values for all elements, is appended as the 145th patch for classification purposes.
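
As a sketch of this patching step (assuming PyTorch and channels-first images), a 36 x 36 image with a 3 x 3 patch size produces 144 patches of 27 values each, after which the all-zero dummy patch is appended:

```python
# A hedged sketch of patch extraction plus the auxiliary dummy patch.
import torch

img = torch.randn(1, 3, 36, 36)                                   # (b, channels, H, W)
patches = img.unfold(2, 3, 3).unfold(3, 3, 3)                     # (b, 3, 12, 12, 3, 3)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 144, -1)   # (b, 144, 27)
dummy = torch.zeros(1, 1, patches.shape[-1])                      # all-zero dummy patch
patches = torch.cat([patches, dummy], dim=1)                      # (b, 145, 27)
```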

Table 1: Configurations for various sizes of mLiT
Configuration S XS XXS
Embedding size 144 128 108
No. of layers (\(L_E\)) 15 12 9
Hidden size 72-144 32-96 27-81
No. of Attn. heads 8 8 6
No. of Attn. groups 4 4 3
No. of experts 3-5 3-5 3-5
k 2 2 2
No. of stages 3 3 3
No. of parameters 2.36M 1.21M 0.66M
Dropout rate @ SwiGLU 0.1 0.1 0.1
Table 2: Configurations for decoder in mmLiT
Configuration S,XS,XXS
Embedding size 108
No. of layers (\(L_D\)) 4
Hidden size 72
No. of Attn. heads 6
No. of Attn. groups 3
No. of experts 3
k 2
No. of parameters 0.34M
Dropout rate @ SwiGLU 0.1

All linear layers within the MoE based transformer encoder layers include bias. However, we exclude bias from other linear projection layers in both the encoder and decoder. For initializing weights and biases across all layer types, we rely on the default methods provided by PyTorch. The same applies to our approach to layer normalization [17], where we use PyTorch’s default setup.

3.1 Pre-training

mmLiT-S, mmLiT-XS, and mmLiT-XXS were pre-trained on Cifar100 for 4000, 6000 and 8000 epochs, respectively. We employed the AdamW optimizer [18] with a weight decay of 0.05. The initial 5% of the total epochs were designated for warm-up [19]. We followed this with a cosine decay schedule [20] for the learning rate and adhered to the linear learning rate scaling rule with a base learning rate of \(3 \times 10^{-4}\) [8], [19]:

\[\mathrm{lr} = \mathrm{base\_lr} \times \mathrm{batch\_size} /256 \label{eq:lr}\tag{4}\] See Table 3 for more details.
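
For example, pre-training mmLiT-S with a batch size of 840 (Table 3) gives an effective learning rate of \[\mathrm{lr} = 3 \times 10^{-4} \times 840 / 256 \approx 9.8 \times 10^{-4}.\]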

Table 3: Parameters and configuration for pre-training
Configuration Value
Optimizer AdamW
Weight decay 0.05
Base learning rate \(3 \times 10^{-4}\)
Learning rate schedule Cosine decay
Warm-up epochs 200 (S), 300 (XS), 400 (XXS)
Batch size 840 (S), 1280 (XS, XXS)
Horizontal flipping \(p=0.5\)
Random resized cropping \([0.6, 1]\)
Color normalization \(\mathrm{mean}=0.5, \mathrm{std}=0.5\)
Discounting factor (\(\alpha\)) 0.1
MoE loss coefficient (\(\beta\)) 0.5

3.2 Fine-tuning

For each pre-trained model, we conducted two sets of fine-tuning experiments. First, we fine-tuned the models pre-trained on Cifar100 for Cifar100 classification over 300 epochs. Second, we evaluated the transfer learning capabilities of the models by fine-tuning each one on Cifar100, Cifar10, Flowers102, and SVHN for 100 epochs. Table 4 shows the configuration for the first set of fine-tuning, and Table 5 shows the deviations in configurations with respect to Table 4 for the second set of fine-tuning.

Additionally, we evaluated the performance of various sizes of mLiT models (trained from scratch with only supervised learning). Table 6 shows the deviations in configurations for supervised learning on mLiT on the four datasets. We adhere to the linear learning rate scaling rule (Eq. 4) for all fine-tunings and supervised learning. All pre-trainings and fine-tunings were done using mixed precision (torch.float16).

Table 4: Hyperparameters for the first set of fine-tuning
Hyperparameter Value
Optimizer AdamW
Weight decay 0.05
Base learning rate \(5 \times 10^{-3}\)
Learning rate schedule Cosine decay
Layer-wise decay 0.9
Total epochs 300
Warm-up epochs 20
Batch size 448
Horizontal flipping \(p=0.5\)
Random resized cropping \([0.8, 1]\)
AutoAugment Policy for Cifar10
Color normalization \(\mathrm{mean}=0.5, \mathrm{std}=0.5\)


Table 5: Hyperparameters deviations for the second set of fine-tuning. The learning rate is in \(10^{-3}\).
Hyperparameter Flowers102 Svhn Cifar10 Cifar100
Batch size 16 256 256 128
Learning rate 10 2.5 2.5 5
Layer-wise decay 0.9 0.9 0.8 0.9
Warm-up epoch 10 10 5 5
AutoAugment policy Cifar10 Svhn Cifar10 Cifar10


Table 6: Hyperparameters deviations for supervised learning on mLiT across four datasets. The learning rate is in \(10^{-3}\).
Hyperparameter Flowers102 Svhn Cifar10 Cifar100
Batch size 16 256 256 128
Learning rate 10 1 1 2
Layer-wise decay 1 1 1 1
Warm-up epoch 10 10 5 5
AutoAugment policy Cifar10 Svhn Cifar10 Cifar10

4 Results

4.1 On Pre-Training

Figure 6 presents the pre-training losses for various sizes of mmLiT. The reductions in training losses are approximately linear when plotted on logarithmic scales for both the x-axis (epochs) and y-axis (loss values), indicating that the loss decays roughly as a power law of the number of epochs.

mmLiT-S exhibited the lowest loss, reaching 0.020404 at epoch 4000. mmLiT-XS and mmLiT-XXS achieved losses of 0.021051 and 0.021514 at epochs 6000 and 8000, respectively. The pre-training of mmLiT-S was notably stable, as evidenced by the minimal variability in its training loss. In contrast, the training losses for mmLiT-XS and mmLiT-XXS showed greater variability and more frequent, pronounced spikes, particularly noticeable after epoch 1000.

Figure 6: Pre-training losses for mmLiT-S, mmLiT-XS, and mmLiT-XXS. The pre-training of mmLiT-S was the most stable, with minimal variability, as can be seen from the thickness of the curves.

4.2 On Fine-Tuning

Table 7 compares the performances of models with similar sizes, all trained or fine-tuned over 300 epochs. mmLiT-S achieves an accuracy nearly 1% higher than Mae-ViT-C100, despite having only two-thirds the parameters. mmLiT-XS, which is one-third the size of Mae-ViT-C100, performs only about 1.5% lower. More interestingly, mmLiT-XXS, with just 18% of the parameters of Mae-ViT-C100, remains competitive, trailing by only 3.3% in accuracy. Furthermore, mmLiT-XXS slightly outperforms ResNet56, even though ResNet56 has roughly 27% more parameters (0.85M vs. 0.67M).

Table 8 illustrates the transfer learning capabilities of the models pre-trained only on Cifar100, with reference to the results reported in [21]. It can be seen that mmLiT is competitive even at a scale of 0.67M parameters, where ViT-T+SSAT is almost 9 times larger and the rest are at least 30 times larger than mmLiT-XXS. Furthermore, with the exception of CVT-13, which has convolutions in its architecture, mLiT always performs better than vanilla ViT and Swin-T, even at the smallest scale, on Cifar100, Cifar10 and Svhn. Flowers102 is a fine-grained classification dataset, so it is no surprise that mLiT is not competitive at an image size of 36 x 36.


Table 7: Comparison of Top-1 validation accuracy
Model Cifar100 # Params
Convolutional Networks (Designed for CIFAR)
ResNet56 [22] 74.81% 0.85 M
ResNet110[22] 76.63% 1.73 M
Vision Transformers[23]
ViT-12/16 57.97% 85.63 M
ViT-Lite-7/16 52.87% 3.89 M
ViT-Lite-7/8 67.27% 3.74 M
ViT-Lite-7/4 73.94% 3.72 M
Compact Vision Transformers[23]
CVT-7/8 70.11% 3.74 M
CVT-7/4 76.49% 3.72 M
Compact Convolutional Transformers[23]
CCT-2/3 \(\times\) 2 66.93% 0.28 M
CCT-7/3 \(\times\) 2 77.72% 3.85 M
MAE Vision Transformers [7]
Mae-ViT-C100 78.27% 3.64 M
Current work on mmLiT
mmLiT-S 79.15% 2.38M
mmLiT-XS 76.70% 1.22M
mmLiT-XXS 74.95% 0.67M
Table 8: Top-1 classification accuracy (in percentage) of various models on Cifar100, Cifar10, Svhn, and Flowers102. All models were trained for 100 epochs. For mmLiT and mLiT, all images were resized to 36 x 36, including Flowers-102. For the other models, the image size was set to 32 x 32 for Cifar100, Cifar10, and Svhn, and 224 x 224 for Flowers102.
Model # Param. (M) Cifar100 Cifar10 Svhn Flowers102
Vanilla variants of ViT [21]
ViT-T [24] 5.4 55.11 79.47 92.04 45.51
ViT-S [24] 21.4 54.08 79.93 94.45 56.17
CVT-13 [25] 20. 73.50 89.02 91.47 54.29
Swin-T [26] 29. 53.28 59.47 71.60 34.51
ResNet-50 [21], [22] 25.6 72.80 91.78 96.45 46.92
mLiT (no pre-training)
mLiT-S 2.38 60.58 84.03 96.01 34.41
mLiT-XS 1.22 57.93 82.74 95.43 35.41
mLiT-XXS 0.67 56.98 80.78 94.97 35.34
Variants of ViT augmented by Masked Autoencoders [21]
ViT-T+SSAT 5.8 69.64 91.65 97.52 57.20
ViT-S+SSAT 21.8 73.37 94.05 97.87 61.15
CVT-13+SSAT 20.3 75.16 95.93 97.00 68.82
Swin-T+SSAT 29.3 60.68 83.12 85.83 54.72
mmLiT (pre-trained on Cifar100 )
mmLiT-S 2.38 78.18 94.62 97.39 59.59
mmLiT-XS 1.22 75.91 94.01 97.26 57.46
mmLiT-XXS 0.67 73.42 92.21 97.32 48.59

5 Discussion

The results from Table 8 show that ViT can learn better with streamlined MoE at a much smaller scale. In our current proposal, \(\mathbf{V}\) and \(\mathbf{W_2}\) are shared across an MoE layer. The main consideration behind this arrangement is the non-linear \(\mathrm{silu}\) applied on \(\mathbf{xW}\). Nevertheless, we have also explored the two other possible sharing arrangements on mmLiT-S (see Table 9), revealing that keeping \(\mathbf{W}\) expert-specific gives a slight advantage.

Table 9: Investigation into various ways of sharing linear transformations in SwiGLU across MoE layer
Items shared Acc. @ Cifar100
\(\mathbf{V}\), \(\mathbf{W_2}\) 79.15
\(\mathbf{V}\), \(\mathbf{W}\) 78.34
\(\mathbf{W}\), \(\mathbf{W_2}\) 78.87

In our design, we progressively reduce the hidden size and increase the number of experts over several stages. This arrangement is inspired by convolutional neural networks [27], where the size of the feature maps is reduced successively across layers and the number of feature maps increases as the network gets deeper. We found that, with the MoE arrangement, a reduction in hidden size leads to little deterioration in performance provided the reduction is not significant in the early layers and does not exceed 75% in the last layers. Further experiments are required in the future to confirm this.

When we performed fine-tuning on the four datasets, we started by applying the fine-tuning setup we used for Cifar100 to the other three datasets. However, overfitting at an early stage became a consistent issue, so we had to modify the setup to accommodate each dataset. We encountered a similar problem when we tried to perform supervised learning with mLiT from scratch. Therefore, we reported the setups in Table 4 and Table 5 for clarity. It is worth noting that we did not perform a thorough search over the configurations; we simply adjusted the hyperparameters to avoid overfitting during training. Furthermore, unlike in [21], we did not use advanced augmentation techniques in any of our fine-tunings.

During pre-training of mmLiT, we used different numbers of epochs for the different sizes of models. From Figure 6, it is clear that, from the perspective of loss, more epochs should have been allocated for mmLiT-XS and mmLiT-XXS. However, one of the downsides of masked autoencoders is the unclear relationship between the loss of a model and its subsequent performance in downstream tasks. In our experience, we have seen cases where further pre-training for a few hundred epochs at a later stage leads to similar or poorer performance in classification. This might be an issue specific to small datasets. Despite the above, we believe that with such datasets, further pre-training over thousands of epochs should lead to a better model.

In this study, we found that mLiT pre-trained on only 50,000 images can serve as a sort of foundation model [28] even at the smallest size. On both Cifar10 and Flowers102, models pre-trained on Cifar100 improve by at least 10%. On Cifar10, mmLiT-XXS can reach 90% accuracy in fewer than 40 epochs. On Svhn, the improvements are modest, partly because mLiT already performs well without pre-training and possibly also because Cifar100 lacks similar images. We believe a model pre-trained on a slightly larger and more diverse dataset can perform competitively on various simpler tasks at these tiny scales.

In the literature, it is common to integrate convolutional layers [23], [29] or employ a convolutional ‘teacher’ [24] for imparting inductive bias. The use of streamlined MoE appears to help alleviate the lack of inductive bias. However, as demonstrated in [7], [21], the problem of inductive bias can also be overcome with a masked auto-encoder setup.

6 Conclusion

In this paper, we demonstrated the potential of a streamlined MoE architecture to create very lightweight vision transformers. By sharing parameters within MoE layers and adopting depth-wise scaling, we achieved a balance between model complexity and performance. Furthermore, our findings suggest that pre-training on a slightly larger and more diverse dataset could enhance the model’s versatility and efficacy across various tasks. The streamlined MoE approach appears promising in mitigating the lack of inductive bias, particularly when used in conjunction with masked autoencoder setups.

References

[1]
Di Liu, Hao Kong, Xiangzhong Luo, Weichen Liu, and Ravi Subramaniam. Bringing ai to edge: From deep learning’s perspective. Neurocomputing, 485:297–320, 2022.
[2]
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
[3]
J. Zheng, L. Yang, Y. Li, K. Yang, Z. Wang, and J. Zhou. Lightweight vision transformer with spatial and channel enhanced self-attention. In 2023 IEEE/CVF International Conference on Computer Vision Workshops (ICCVW), pages 1484–1488, Los Alamitos, CA, USA, oct 2023. IEEE Computer Society.
[4]
Junting Pan, Adrian Bulat, Fuwen Tan, Xiatian Zhu, Lukasz Dudziak, Hongsheng Li, Georgios Tzimiropoulos, and Brais Martinez. Edgevits: Competing light-weight cnns on mobile devices with vision transformers. In European Conference on Computer Vision, 2022.
[5]
Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, and Ling Shao. Pyramid vision transformer: A versatile backbone for dense prediction without convolutions. In 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pages 548–558, 2021.
[6]
Sachin Mehta and Mohammad Rastegari. Mobilevit: Light-weight, general-purpose, and mobile-friendly vision transformer. In International Conference on Learning Representations, 2022.
[7]
Jen Hong Tan. Pre-training of lightweight vision transformers on small datasets with minimally scaled images, 2024.
[8]
Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross Girshick. Masked autoencoders are scalable vision learners. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 16000–16009, 2022.
[9]
Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. In International Conference on Learning Representations, 2017.
[10]
Noam Shazeer. Glu variants improve transformer, 2020.
[11]
Sachin Mehta, Marjan Ghazvininejad, Srinivasan Iyer, Luke Zettlemoyer, and Hannaneh Hajishirzi. Delight: Deep and light-weight transformer. In International Conference on Learning Representations, 2021.
[12]
Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. Gqa: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245, 2023.
[13]
David Rau. Sparsely-gated mixture-of-experts pytorch implementation, 2019.
[14]
Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical Report 0, University of Toronto, Toronto, Ontario, 2009.
[15]
Maria-Elena Nilsback and Andrew Zisserman. Automated flower classification over a large number of classes. In Indian Conference on Computer Vision, Graphics and Image Processing, Dec 2008.
[16]
Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y. Ng. Reading digits in natural images with unsupervised feature learning. In NIPS Workshop on Deep Learning and Unsupervised Feature Learning 2011, 2011.
[17]
Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization, 2016.
[18]
Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.
[19]
Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: Training imagenet in 1 hour, 2018.
[20]
Ilya Loshchilov and Frank Hutter. SGDR: Stochastic gradient descent with warm restarts. In International Conference on Learning Representations, 2017.
[21]
Srijan Das, Tanmay Jain, Dominick Reilly, Pranav Balaji, Soumyajit Karmakar, Shyam Marjit, Xiang Li, Abhijit Das, and Michael Ryoo. Limited data, unlimited potential: A study on vits augmented by masked autoencoders. 2024 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), 2024.
[22]
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
[23]
Ali Hassani, Steven Walton, Nikhil Shah, Abulikemu Abuduweili, Jiachen Li, and Humphrey Shi. Escaping the big data paradigm with compact transformers. arXiv preprint arXiv:2104.05704, 2021.
[24]
Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé Jégou. Training data-efficient image transformers & distillation through attention. In International conference on machine learning, pages 10347–10357. PMLR, 2021.
[25]
Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, and Lei Zhang. Cvt: Introducing convolutions to vision transformers. In 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pages 22–31, 2021.
[26]
Z. Liu, Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, S. Lin, and B. Guo. Swin transformer: Hierarchical vision transformer using shifted windows. In 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pages 9992–10002, Los Alamitos, CA, USA, oct 2021. IEEE Computer Society.
[27]
Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. Advances in neural information processing systems, 25, 2012.
[28]
Rishi Bommasani, Drew A. Hudson, Ehsan Adeli, Russ Altman, Simran Arora, Sydney von Arx, Michael S. Bernstein, Jeannette Bohg, Antoine Bosselut, Emma Brunskill, Erik Brynjolfsson, S. Buch, Dallas Card, Rodrigo Castellon, Niladri S. Chatterji, Annie S. Chen, Kathleen A. Creel, Jared Davis, Dora Demszky, Chris Donahue, Moussa Doumbouya, Esin Durmus, Stefano Ermon, John Etchemendy, Kawin Ethayarajh, Li Fei-Fei, Chelsea Finn, Trevor Gale, Lauren E. Gillespie, Karan Goel, Noah D. Goodman, Shelby Grossman, Neel Guha, Tatsunori Hashimoto, Peter Henderson, John Hewitt, Daniel E. Ho, Jenny Hong, Kyle Hsu, Jing Huang, Thomas F. Icard, Saahil Jain, Dan Jurafsky, Pratyusha Kalluri, Siddharth Karamcheti, Geoff Keeling, Fereshte Khani, O. Khattab, Pang Wei Koh, Mark S. Krass, Ranjay Krishna, Rohith Kuditipudi, Ananya Kumar, Faisal Ladhak, Mina Lee, Tony Lee, Jure Leskovec, Isabelle Levent, Xiang Lisa Li, Xuechen Li, Tengyu Ma, Ali Malik, Christopher D. Manning, Suvir P. Mirchandani, Eric Mitchell, Zanele Munyikwa, Suraj Nair, Avanika Narayan, Deepak Narayanan, Benjamin Newman, Allen Nie, Juan Carlos Niebles, Hamed Nilforoshan, J. F. Nyarko, Giray Ogut, Laurel Orr, Isabel Papadimitriou, Joon Sung Park, Chris Piech, Eva Portelance, Christopher Potts, Aditi Raghunathan, Robert Reich, Hongyu Ren, Frieda Rong, Yusuf H. Roohani, Camilo Ruiz, Jack Ryan, Christopher R’e, Dorsa Sadigh, Shiori Sagawa, Keshav Santhanam, Andy Shih, Krishna Parasuram Srinivasan, Alex Tamkin, Rohan Taori, Armin W. Thomas, Florian Tramèr, Rose E. Wang, William Wang, Bohan Wu, Jiajun Wu, Yuhuai Wu, Sang Michael Xie, Michihiro Yasunaga, Jiaxuan You, Matei A. Zaharia, Michael Zhang, Tianyi Zhang, Xikun Zhang, Yuhui Zhang, Lucia Zheng, Kaitlyn Zhou, and Percy Liang. On the opportunities and risks of foundation models. ArXiv, 2021.
[29]
Benjamin Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, and Matthijs Douze. Levit: a vision transformer in convnet’s clothing for faster inference. In Proceedings of the IEEE/CVF international conference on computer vision, pages 12259–12269, 2021.