July 25, 2024
In this paper, we explore a strategy that uses Mixture-of-Experts (MoE) to streamline, rather than augment, vision transformers. Each expert in an MoE layer is a SwiGLU feedforward network, where \(\mathbf{V}\) and \(\mathbf{W_2}\) are shared across the layer. No complex attention or convolutional mechanisms are employed. Depth-wise scaling is applied to progressively reduce the size of the hidden layer, and the number of experts is increased in stages. Grouped query attention is used. We studied the proposed approach with and without pre-training on small datasets and investigated whether transfer learning works at this scale. We found that the architecture is competitive even at a size of 0.67M parameters.
In real-world applications of computer vision, such as edge intelligence, small and performant models are still preferred to overcome computational challenges [1]. Vision Transformers (ViTs) [2] have achieved remarkable results, but their performance drops significantly when the model size and dataset are small [3]. Consequently, there are studies investigating lightweight vision transformers that perform well on mid-size datasets. Almost all of these studies either use new types of attention blocks [4], [5] or integrate convolutional mechanisms [6] into their architectures.
On the other hand, Tan [7] has shown that by employing Masked Auto-Encoder (MAE) [8] as a pre-training strategy, it is possible to get ViT to learn effectively from small datasets. In that work, the ViT consists of 12 transformer encoder layers, each containing a multi-head attention component and a feedforward network. The feedforward network consists of two linear layers: the first expands the output to twice, rather than four times, the embedding size, and the second reduces the output back to the embedding size. To further lighten the model, reducing the expanded output size in the middle of the feedforward network can help, but excessive reduction can negatively affect model performance.
With these considerations in mind, we designed an architecture that uses Mixture-of-Experts (MoE) [9] to streamline vision transformers. In our architecture, each expert in an MoE layer is formed by a SwiGLU [10] feedforward network. By design, SwiGLU is heavier in terms of parameters than a typical multi-layer perceptron. However, with several experts in an MoE layer, we are able to make the hidden size in SwiGLU smaller than the embedding size without negatively affecting model performance. Furthermore, we share 2 out of the 3 linear transformations in each SwiGLU across the layer. This significantly lowers the parameter count while maintaining the strength of MoE. Beyond that, to further reduce the number of parameters, we progressively increase the number of experts in the MoE in stages while linearly reducing the hidden size with depth, borrowing the idea from depth-wise scaling [11]. Lastly, we use grouped query attention [12] to keep the parameter count low. Source code will be provided in the near future.
Our proposed approach consists of two parts: mLiT and mmLiT. mLiT is a mixture-of-experts based Lightweight vision Transformer, and mmLiT is an mLiT pre-trained and fine-tuned using the masked auto-encoder pre-training strategy.
Similar to the original ViT [2], we start by dividing each image into non-overlapping patches. Each patch is linearly transformed into an embedding, which is augmented by a learnable positional embedding. The resulting embeddings pass through a series of MoE-based transformer encoder layers. Figure 1 shows the overall structure of our transformer encoder layer.
Grouped query attention divides query heads into \(G\) groups, with each group sharing a single key head and value head. For instance, GQA-1, with a single group, is equivalent to multi-query attention (MQA); while GQA-\(H\), with groups equal to the number of heads (\(H\)), is equivalent to multi-head attention (MHA). An intermediate number of groups results in a model that is higher quality than MQA but faster than MHA [12], providing a favorable trade-off between performance and memory bandwidth.
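To make the grouping concrete, the following is a minimal PyTorch sketch of grouped query attention under our own assumptions (the linear projections that produce the queries, keys and values are omitted, and all names and shapes are illustrative rather than the authors' implementation):

```python
import torch

def grouped_query_attention(q, k, v, num_heads=8, num_groups=4):
    # q: (batch, seq, num_heads * head_dim); k, v: (batch, seq, num_groups * head_dim)
    b, n, d = q.shape
    head_dim = d // num_heads
    q = q.view(b, n, num_heads, head_dim).transpose(1, 2)    # (b, H, n, hd)
    k = k.view(b, n, num_groups, head_dim).transpose(1, 2)   # (b, G, n, hd)
    v = v.view(b, n, num_groups, head_dim).transpose(1, 2)
    # each group of H // G consecutive query heads shares one key/value head
    k = k.repeat_interleave(num_heads // num_groups, dim=1)  # (b, H, n, hd)
    v = v.repeat_interleave(num_heads // num_groups, dim=1)
    attn = (q @ k.transpose(-2, -1)) / head_dim ** 0.5
    out = attn.softmax(dim=-1) @ v                           # (b, H, n, hd)
    return out.transpose(1, 2).reshape(b, n, d)
```

With `num_groups=1` this reduces to MQA, and with `num_groups=num_heads` to standard MHA, matching the description above.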
A single Mixture-of-Experts layer comprises \(t\) expert networks \(E_1, E_2, \cdots, E_t\) and a gating network \(G\) that directs each input vector/embedding to the relevant expert(s) in the layer. For each vector or embedding, the output of \(G\) is a sparse \(t\)-dimensional vector. Figure 2 shows an overview of the working of MoE when the input is a vector.
Let’s denote a vector/embedding by \(x\), the output of the \(i\)-th expert by \(E_i(x)\), and the \(i\)-th component of the output of the gating network \(G(x)\) by \(G(x)_i\). The output \(y\) of the MoE is given by:
\[y = \sum_{i=1}^t G(x)_i E_i(x)\]
As the output \(G(x)\) is sparse, in MoE, no computation is performed on the \(j\)-th expert when \(G(x)_j = 0\). In many applications, however, the input to the layer is generally of shape \((b,n,m)\), where \(b\), \(n\), and \(m\) denote batch size, the number of embeddings and the embedding size, respectively. Therefore, it is common to reshape the input tensor into 2D [13] before the MoE layer. Figure 3 shows the workflow of the actual implementation of MoE.
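As an illustration of this dispatching, the sketch below (our assumption of a typical loop-based implementation, not the authors' code) flattens the \((b, n, m)\) input to 2D, weights each embedding with the sparse gate output, and skips experts that receive a zero weight:

```python
import torch

def moe_forward(x, gate, experts):                       # x: (b, n, m)
    b, n, m = x.shape
    x = x.reshape(b * n, m)                               # reshape the input tensor into 2D
    gates = gate(x)                                       # sparse weights G(x), shape (b*n, t)
    y = torch.zeros_like(x)
    for j, expert in enumerate(experts):
        idx = gates[:, j].nonzero(as_tuple=True)[0]       # embeddings routed to expert j
        if idx.numel() > 0:                               # no computation when G(x)_j = 0
            y[idx] += gates[idx, j:j + 1] * expert(x[idx])
    return y.reshape(b, n, m)
```

In practice the per-expert loop is usually replaced by a batched dispatch; it is kept here only for readability.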
Noisy Top-K Gating Network. We follow the proposal by Shazeer et al. [9] that adds sparsity and noise to a softmax gating mechanism. Assume a reshaped input tensor \(\mathbf{x}\) with a shape \((b \cdot n, m)\) and two trainable weight matrices \(\mathbf{W}_g\) and \(\mathbf{W}_{noise}\), both of shape \((m, t)\), where \(t\) is the total number of experts in the layer and \(k\) is the number of expert(s) to be selected. The output of the gating network is given by
\[G(\mathbf{x}) = \mathrm{Softmax}_k \left( H(\mathbf{x}) \right)\] where \[H(\mathbf{x}) = \mathbf{x}\mathbf{W}_g + \mathrm{randn\_like}\left(\mathbf{x}\mathbf{W}_g\right) \odot \mathrm{Softplus}\left( \mathbf{x}\mathbf{W}_{noise} \right)\]
\(\odot\) denotes the element-wise product. \(\mathrm{randn\_like}\) generates a tensor filled with random numbers from a normal distribution with mean 0 and variance 1, and the shape of the tensor is equal to the shape of the output from \(\mathbf{x} \mathbf{W}_g\). \(\mathrm{Softmax}_k\) applies softmax (row-wise) only on the top \(k\) elements, with the rest set to 0 in the output. The equation below illustrates how \(\mathrm{Softmax}_k\) works when \(k\) is 2 and \(t\) (the number of experts) is 4:
\[\mathrm{Softmax}_k \left( \left[ \begin{array}{rrrr} 4.4742 & -5.6365 & 6.8226 & 0.9960 \\ 3.5298 & 2.3049 & 1.2113 & -1.3946 \\ -2.2414 & 0.3925 & 1.6676 & -1.9253 \\ \end{array} \right]\right) = \begin{bmatrix} 0.0872 & 0.0000 & 0.9128 & 0.0000 \\ 0.7729 & 0.2271 & 0.0000 & 0.0000 \\ 0.0000 & 0.2184 & 0.7816 & 0.0000 \\ \end{bmatrix}\]
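A short sketch of this gating mechanism, under our assumptions (hypothetical helper name; PyTorch):

```python
import torch
import torch.nn.functional as F

def noisy_topk_gate(x, W_g, W_noise, k):
    # x: (b*n, m); W_g, W_noise: (m, t)
    clean = x @ W_g
    H = clean + torch.randn_like(clean) * F.softplus(x @ W_noise)   # H(x)
    topk_vals, topk_idx = H.topk(k, dim=-1)
    gates = torch.zeros_like(H)
    # Softmax_k: softmax over the top-k logits only; all other entries stay 0
    gates.scatter_(-1, topk_idx, topk_vals.softmax(dim=-1))
    return gates, H
```

The noisy logits \(H(\mathbf{x})\) are returned alongside \(G(\mathbf{x})\) because they are reused by the load balancing loss described next.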
Losses for MoE. To encourage all experts to have equal importance, two loss functions are introduced: a load balancing loss and an importance loss. To calculate the load balancing loss, we first compute the following probability:
\[P(\mathbf{x}) = \Phi \left( \frac{\mathbf{x}\mathbf{W}_g - \Psi(H(\mathbf{x}))}{\mathrm{Softplus}(\mathbf{x}\mathbf{W}_{noise})} \right)\] where \(\Phi\) is the cumulative distribution function of a standard normal distribution. To calculate the output of \(\Psi\), we let \(\mathbf{H}\) denote the output of \(H(\mathbf{x})\), which has a shape of \((b \cdot n, t)\), with each element in \(\mathbf{H}\) denoted by \(h_{r,c}\). Furthermore, we let \(\mathbf{H}_r\) denote the row with index \(r\) in \(\mathbf{H}\). Similarly, \(\psi_{r,c}\) denotes each element in the output of \(\Psi(H(\mathbf{x}))\). If we let \(l^{k}_r\) and \(l^{k+1}_r\) denote the \(k^{\mathrm{th}}\) and \((k+1)^{\mathrm{th}}\) largest values, respectively, for a row \(r\) in \(\mathbf{H}\), with \(l^{k}_r > l^{k+1}_r\), then
\[\psi_{r,c} = \left\{ \begin{array}{ll} l^{k+1}_r & \mathrm{if}\ h_{r,c} \geq l^{k}_r \\ l^{k}_r & \mathrm{if}\ h_{r,c} < l^{k}_r \\ \end{array} \right.\] where \(k\) is the number of experts to be selected in the layer, and \(l^{k}_r\) and \(l^{k+1}_r\) are taken from \(\mathbf{H}_r\). As an example, assume we have
\[\mathbf{H}= \left[ \begin{array}{rrrr} 4.4742 & -5.6365 & 6.8226 & 0.9960 \\ 3.5298 & 2.3049 & 1.2113 & -1.3946 \\ -2.2414 & 0.3925 & 1.6676 & -1.9253 \\ \end{array} \right]\] and \(k\) is 2, then for each row in \(\mathbf{H}\) (with row index starting from 0), we have
\[\begin{array}{crr} \hline \text{row} & k^{\mathrm{th}}\ \text{largest} & (k+1)^{\mathrm{th}}\ \text{largest} \\ \hline 0 & 4.4742 & 0.9960 \\ 1 & 2.3049 & 1.2113 \\ 2 & 0.3925 & -1.9253 \\ \hline \end{array}\] The output of \(\Psi\) is
\[\Psi(\mathbf{H})= \left[ \begin{array}{rrrr} 0.9960 & 4.4742 & 0.9960 & 4.4742 \\ 1.2113 & 1.2113 & 2.3049 & 2.3049 \\ 0.3925 & -1.9253 & -1.9253 & 0.3925 \\ \end{array} \right]\] Lastly, with the probability calculated, the load balancing loss is given by:
\[L_{load}(\mathbf{x})= w_{load} \times \mathrm{CV}\left( \mathrm{sum}_0 (P(\mathbf{x}))\right)\] where \(w_{load}\) is a hand-tuned scaling factor, \(\mathrm{CV}\) is the function that calculates the coefficient of variation of its input, and \(\mathrm{sum}_0\) performs column-wise summation on \(P(\mathbf{x})\). To determine the importance loss, we simply compute
\[L_{importance} = w_{importance} \times \mathrm{CV}\left(\mathrm{sum}_0 (G(\mathbf{x}))\right)\] and the total loss is
\[L = L_{importance}+L_{load}\label{eq:moe95loss}\tag{1}\] In our implementation, we set \(w_{importance} = w_{load} = 1 \times 10^{-2}\).
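The following is a minimal sketch of these two losses under our assumptions (\(\mathrm{CV}\) computed as the standard deviation divided by the mean; all function and argument names are illustrative). It reuses the clean logits, the noise scale, and the noisy logits \(H(\mathbf{x})\) from the gating step:

```python
import torch
import torch.nn.functional as F

def cv(v, eps=1e-10):                                     # coefficient of variation
    return v.std() / (v.mean() + eps)

def moe_aux_loss(x, W_g, W_noise, H, gates, k, w_importance=1e-2, w_load=1e-2):
    clean = x @ W_g                                       # x W_g
    noise_std = F.softplus(x @ W_noise)                   # Softplus(x W_noise)
    top = H.topk(k + 1, dim=-1).values                    # k+1 largest values per row
    in_top_k = H >= top[:, k - 1:k]                       # h_{r,c} >= k-th largest
    psi = torch.where(in_top_k, top[:, k:k + 1], top[:, k - 1:k])   # Psi(H(x))
    P = torch.distributions.Normal(0.0, 1.0).cdf((clean - psi) / noise_std)
    L_load = w_load * cv(P.sum(dim=0))                    # column-wise sum, then CV
    L_importance = w_importance * cv(gates.sum(dim=0))
    return L_importance + L_load
```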
SwiGLU. Each expert in an MoE layer is a SwiGLU FeedForward Network (FFN). Assume we have a tensor \(\mathbf{x}\) of shape \((n, m)\), where \(n\) is the number of embeddings and \(m\) is the embedding size. We define \(\mathbf{W}\), \(\mathbf{V}\) and \(\mathbf{W_2}\) of shapes \((m,d_h)\), \((m, d_h)\) and \((d_h, m)\), respectively, where \(d_h\) is the hidden size. The output of the network is given by:
\[\mathrm{FFN_{SwiGLU}}(\mathbf{x}) = \left(\mathrm{silu}\left(\mathbf{xW} \right) \odot \left(\mathbf{xV} \right) \right)\mathbf{W}_2\] Biases are omitted in the above for \(\mathbf{W}, \mathbf{V}\) and \(\mathbf{W_2}\). \(\mathrm{silu}\) is the swish function:
\[\mathrm{silu}(x) = x \cdot \sigma(x)\] where \(\sigma(x)\) is the logistic sigmoid. Dropout is applied to the output of the SwiGLU FFN.
In mLiT, \(\mathbf{V}\), \(\mathbf{W_2}\) and their corresponding biases are shared across all experts in an MoE layer.
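A sketch of one way the shared-parameter SwiGLU experts could be implemented (module and method names are ours, not the authors'; biases are kept, as in the paper):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedSwiGLUExperts(nn.Module):
    def __init__(self, m, d_h, num_experts, dropout=0.1):
        super().__init__()
        # W is expert-specific; V and W2 (with their biases) are shared across the layer
        self.W = nn.ModuleList([nn.Linear(m, d_h) for _ in range(num_experts)])
        self.V = nn.Linear(m, d_h)
        self.W2 = nn.Linear(d_h, m)
        self.drop = nn.Dropout(dropout)

    def expert(self, i, x):                               # x: (tokens, m)
        return self.drop(self.W2(F.silu(self.W[i](x)) * self.V(x)))
```

Only the expert-specific \(\mathbf{W}\) (plus its bias) grows with the number of experts, which is what keeps the parameter count low as experts are added.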
Our vision transformer does not have a fixed hidden size \(d_h\) across the transformer encoder layers. Instead, the hidden size is linearly reduced from the first layer to the last layer. Let \(d^{\mathrm{first}}_h\) and \(d^{\mathrm{last}}_h\) denote the hidden sizes of the first and last layers, respectively; the hidden size of any layer is then given by
\[l^i_h = \left\lfloor \frac{(L_E - 1)-i}{L_E-1}\left(d^{\mathrm{first}}_h-d^{\mathrm{last}}_h \right) \right\rfloor + d^{\mathrm{last}}_h \label{eq:hid95size}\tag{2}\] where \(L_E\) is the total number of transformer encoder layers, \(\lfloor \cdot \rfloor\) is the floor operator, and \(i\) is a layer index starting from 0.
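For illustration, Eq. 2 in code, with a worked example under our reading of Table 1 for the XXS encoder (\(L_E = 9\), hidden size going from 81 in the first layer down to 27 in the last):

```python
def hidden_size(i, L_E, d_first, d_last):
    # Eq. 2: floor-based linear interpolation of the hidden size by depth
    return ((L_E - 1 - i) * (d_first - d_last)) // (L_E - 1) + d_last

# [hidden_size(i, 9, 81, 27) for i in range(9)]
# -> [81, 74, 67, 60, 54, 47, 40, 33, 27]
```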
Furthermore, the number of experts in mLiT is increased at different stages, specifically from 3 to 5 after every 3, 4, or 5 layers, depending on the model size. See Figure 4 for more detail.
In mmLiT, we apply masked auto-encoder pre-training to mLiT, strictly following the process outlined in [7], which is based on the original MAE paper [8]. The masking ratio is set to 0.75. However, unlike in [7], we did not incorporate separate learnable positional embeddings at the decoder's input (see Figure 5). Similar to [7], we compute the loss of the auto-encoder as a sum of the loss on masked patches and an additional, discounted loss on unmasked patches. The total loss is given by
\[\mathrm{Loss} = \mathrm{MSE}_{\mathrm{masked}} + \alpha \cdot \mathrm{MSE}_{\mathrm{unmasked}} + \beta \cdot L \label{eq:loss}\tag{3}\] where \(\alpha\) denotes the discounting factor, \(\beta\) is the loss coefficient for MoE, and \(L\) is calculated from Equation 1.
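A hedged sketch of Eq. 3, assuming the masked and unmasked mean-squared errors are averaged per patch over their respective patch sets (the names and the exact normalization are our assumptions):

```python
import torch
import torch.nn.functional as F

def mae_total_loss(pred, target, mask, moe_loss, alpha=0.1, beta=0.5):
    # pred, target: (b, num_patches, patch_dim); mask: (b, num_patches), 1 = masked
    per_patch = F.mse_loss(pred, target, reduction="none").mean(dim=-1)
    mse_masked = (per_patch * mask).sum() / mask.sum()
    mse_unmasked = (per_patch * (1 - mask)).sum() / (1 - mask).sum()
    return mse_masked + alpha * mse_unmasked + beta * moe_loss
```

The defaults \(\alpha = 0.1\) and \(\beta = 0.5\) follow Table 3.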
We investigate the performance of mLiT and mmLiT at three different sizes: S, XS and XXS. Table 1 shows the details of the encoder of each model, and Table 2 shows the details of the decoder. For mmLiT, we perform self-supervised pre-training on the Cifar100 [14] training set and fine-tune on Cifar100, Cifar10 [14], Flowers102 [15] and Svhn [16]. Additionally, we conduct supervised learning on the aforementioned four datasets.
The input image to the models has a size of 36 x 36, slightly larger than the original 32 x 32 dimensions of these datasets (except Flowers102, whose images are much larger and more varied). This results in each image being divided into 144 patches, given the patch size of 3 x 3. Similar to [7], an auxiliary dummy patch is added to each image during both the pre-training and fine-tuning phases. This dummy patch, which contains zero values for all elements, is appended as the 145th patch for classification purposes.
Configuration | S | XS | XXS |
---|---|---|---|
Embedding size | 144 | 128 | 108 |
No. of layers (\(L_E\)) | 15 | 12 | 9 |
Hidden size | 72-144 | 32-96 | 27-81 |
No. of Attn. heads | 8 | 8 | 6 |
No. of Attn. groups | 4 | 4 | 3 |
No. of experts | 3-5 | 3-5 | 3-5 |
k | 2 | 2 | 2 |
No. of stages | 3 | 3 | 3 |
No. of parameters | 2.36M | 1.21M | 0.66M |
Dropout rate @ SwiGLU | 0.1 | 0.1 | 0.1 |
Configuration | S, XS, XXS |
---|---|
Embedding size | 108
No. of layers (\(L_D\)) | 4
Hidden size | 72
No. of Attn. heads | 6
No. of Attn. groups | 3
No. of experts | 3
k | 2
No. of parameters | 0.34M
Dropout rate @ SwiGLU | 0.1
All linear layers within the MoE-based transformer encoder layers include bias. However, we exclude bias from other linear projection layers in both the encoder and decoder. For initializing weights and biases across all layer types, we rely on the default methods provided by PyTorch. The same applies to our approach to layer normalization [17], where we use PyTorch's default setup.
mmLiT-S, mmLiT-XS, and mmLiT-XXS were pre-trained on Cifar100 for 4000, 6000 and 8000 epochs, respectively. We employed the AdamW optimizer [18] with a weight decay of 0.05. The initial 5% of the total epochs were designated for warm-up [19]. We followed this with a cosine decay schedule [20] for the learning rate and adhered to the linear learning rate scaling rule with a base learning rate of \(3 \times 10^{-4}\) [8], [19]:
\[\mathrm{lr} = \mathrm{base\_lr} \times \mathrm{batch\_size} /256 \label{eq:lr}\tag{4}\] See Table 3 for more details.
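As a concrete example of Eq. 4 (the helper name is ours):

```python
def scaled_lr(base_lr, batch_size):
    # Eq. 4: linear scaling of the base learning rate with batch size
    return base_lr * batch_size / 256

# With the pre-training values from Table 3, e.g. mmLiT-S:
# scaled_lr(3e-4, 840)  ->  ~9.84e-4
```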
Configuration | Value |
---|---|
Optimizer | AdamW |
Weight decay | 0.05 |
Base learning rate | \(3 \times 10^{-4}\) |
Learning rate schedule | Cosine decay |
Warm-up epochs | 200 (S), 300 (XS), 400 (XXS) |
Batch size | 840 (S), 1280 (XS, XXS) |
Horizontal flipping | \(p=0.5\) |
Random resized cropping | \([0.6, 1]\) |
Color normalization | \(\mathrm{mean}=0.5, \mathrm{std}=0.5\) |
Discounting factor (\(\alpha\)) | 0.1 |
MoE loss coefficient (\(\beta\)) | 0.5 |
For each pre-trained model, we conducted two sets of fine-tuning experiments. First, we fine-tuned the models pre-trained on Cifar100 for Cifar100 classification over 300 epochs. Second, we evaluated the transfer learning capabilities of the models by fine-tuning each one on Cifar100, Cifar10, Flowers102, and SVHN for 100 epochs. Table 4 shows the configuration for the first set of fine-tuning, and Table 5 shows the deviations in configurations with respect to Table 4 for the second set of fine-tuning.
Additionally, we evaluated the performance of various sizes of mLiT models (trained from scratch with only supervised learning). Table 6 shows the deviations in configurations for supervised learning of mLiT on the four datasets. We adhere to the linear learning rate scaling rule (Eq. 4) for all fine-tunings and supervised learning. All pre-trainings and fine-tunings were done using mixed precision (torch.float16).
Hyperparameter | Value |
---|---|
Optimizer | AdamW |
Weight decay | 0.05 |
Base learning rate | \(5 \times 10^{-3}\) |
Learning rate schedule | Cosine decay |
Layer-wise decay | 0.9 |
Total epochs | 300 |
Warm-up epochs | 20 |
Batch size | 448 |
Horizontal flipping | \(p=0.5\) |
Random resized cropping | \([0.8, 1]\) |
AutoAugment | Policy for Cifar10 |
Color normalization | \(\mathrm{mean}=0.5, \mathrm{std}=0.5\) |
Hyperparameter | Flowers102 | Svhn | Cifar10 | Cifar100 |
---|---|---|---|---|
Batch size | 16 | 256 | 256 | 128 |
Learning rate | 10 | 2.5 | 2.5 | 5 |
Layer-wise decay | 0.9 | 0.9 | 0.8 | 0.9 |
Warm-up epoch | 10 | 10 | 5 | 5 |
AutoAugment policy | Cifar10 | Svhn | Cifar10 | Cifar10 |
Hyperparameter | Flowers102 | Svhn | Cifar10 | Cifar100 |
---|---|---|---|---|
Batch size | 16 | 256 | 256 | 128 |
Learning rate | 10 | 1 | 1 | 2 |
Layer-wise decay | 1 | 1 | 1 | 1 |
Warm-up epoch | 10 | 10 | 5 | 5 |
AutoAugment policy | Cifar10 | Svhn | Cifar10 | Cifar10 |
Figure 6 presents the pre-training losses for various sizes of mmLiT. The reductions in training losses exhibit approximately a log-linear relationship when plotted on logarithmic scales for both the x-axis (epochs) and y-axis (loss values). This indicates an exponential decrease in losses across epochs.
mmLiT-S exhibited the lowest loss, reaching 0.020404 at epoch 4000. mmLiT-XS and mmLiT-XXS achieved losses of 0.021051 and 0.021514 at epochs 6000 and 8000, respectively. The pre-training of mmLiT-S was notably stable, as evidenced by the minimal variability in its training loss. In contrast, the training losses for mmLiT-XS and mmLiT-XXS showed greater variability and more frequent, pronounced spikes, particularly noticeable after epoch 1000.
Table 7 compares the performances of models with similar sizes, all trained or fine-tuned over 300 epochs. mmLiT-S achieves an accuracy nearly 1% higher than Mae-ViT-C100, despite having only two-thirds the parameters. mmLiT-XS, which is one-third the size of Mae-ViT-C100, exhibits a performance only 1.5% lower. More interestingly, mmLiT-XXS, with just 18% of the parameters of Mae-ViT-C100, remains competitive, trailing by only 3.3% in accuracy. Furthermore, mmLiT-XXS slightly outperforms ResNet56, even though ResNet56 has 0.18M more parameters.
Table 8 illustrates the transfer learning capabilities of the models pre-trained only on Cifar100, with reference to the results reported in [21]. It can be seen that mmLiT is competitive even at a scale of 0.67M parameters, where ViT-T+SSAT is almost 9 times larger and the rest are at least 30 times larger than mmLiT-XXS. Furthermore, with the exception of CVT-13, which has convolutions in its architecture, mLiT always performs better than vanilla ViT and Swin-T, even at the smallest scale, on Cifar100, Cifar10 and Svhn. Flowers102 is a fine-grained classification dataset, so it is no surprise that mLiT is not competitive at an image size of 36 x 36.
Model | Cifar100 | # Params |
---|---|---|
Convolutional Networks (Designed for CIFAR) | ||
ResNet56 [22] | 74.81% | 0.85 M |
ResNet110[22] | 76.63% | 1.73 M |
Vision Transformers[23] | ||
ViT-12/16 | 57.97% | 85.63 M |
ViT-Lite-7/16 | 52.87% | 3.89 M |
ViT-Lite-7/8 | 67.27% | 3.74 M |
ViT-Lite-7/4 | 73.94% | 3.72 M |
Compact Vision Transformers[23] | ||
CVT-7/8 | 70.11% | 3.74 M |
CVT-7/4 | 76.49% | 3.72 M |
Compact Convolutional Transformers[23] | ||
CCT-2/3 \(\times\) 2 | 66.93% | 0.28 M |
CCT-7/3 \(\times\) 2 | 77.72% | 3.85 M |
MAE Vision Transformers [7] | ||
Mae-ViT-C100 | 78.27% | 3.64 M |
Current work on mmLiT | ||
mmLiT-S | 79.15% | 2.38M |
mmLiT-XS | 76.70% | 1.22M |
mmLiT-XXS | 74.95% | 0.67M |
Model | # Param. (M) | Cifar100 | Cifar10 | Svhn | Flowers102 |
---|---|---|---|---|---|
Vanilla variants of ViT [21] | | | | |
ViT-T [24] | 5.4 | 55.11 | 79.47 | 92.04 | 45.51 |
ViT-S [24] | 21.4 | 54.08 | 79.93 | 94.45 | 56.17 |
CVT-13 [25] | 20.0 | 73.50 | 89.02 | 91.47 | 54.29
Swin-T [26] | 29.0 | 53.28 | 59.47 | 71.60 | 34.51
ResNet-50 [21], [22] | 25.6 | 72.80 | 91.78 | 96.45 | 46.92 |
mLiT (no pre-training) | |||||
mLiT-S | 2.38 | 60.58 | 84.03 | 96.01 | 34.41 |
mLiT-XS | 1.22 | 57.93 | 82.74 | 95.43 | 35.41 |
mLiT-XXS | 0.67 | 56.98 | 80.78 | 94.97 | 35.34 |
Variants of ViT augmented by Masked Autoencoders [21] | |||||
ViT-T+SSAT | 5.8 | 69.64 | 91.65 | 97.52 | 57.20 |
ViT-S+SSAT | 21.8 | 73.37 | 94.05 | 97.87 | 61.15 |
CVT-13+SSAT | 20.3 | 75.16 | 95.93 | 97.00 | 68.82 |
Swin-T+SSAT | 29.3 | 60.68 | 83.12 | 85.83 | 54.72 |
mmLiT (pre-trained on Cifar100 ) | |||||
mmLiT-S | 2.38 | 78.18 | 94.62 | 97.39 | 59.59 |
mmLiT-XS | 1.22 | 75.91 | 94.01 | 97.26 | 57.46 |
mmLiT-XXS | 0.67 | 73.42 | 92.21 | 97.32 | 48.59 |
The results from Table 8 show that ViT can learn better with streamlined MoE at a much smaller scale. In our current proposal, \(\mathbf{V}\) and \(\mathbf{W_2}\) are shared across an MoE layer. The main consideration behind this arrangement is the non-linear \(\mathrm{silu}\) applied to \(\mathbf{xW}\). Nevertheless, we have also explored the two other possible sharing arrangements on mmLiT-S (see Table 9), which revealed that keeping \(\mathbf{W}\) expert-specific gives a slight advantage.
In our design, we progressively reduce the hidden size and increase the number of experts over several stages. This arrangement is inspired by convolutional neural networks [27], where the size of the feature maps is reduced successively across layers and the number of feature maps increases as the network gets deeper. We found that, with this MoE arrangement, a reduction in hidden size leads to little deterioration in performance as long as the reduction is not significant in the early layers and does not exceed 75% in the last layers. Further experiments are required in the future to confirm this.
When we performed fine-tuning on the four datasets, we started by applying the fine-tuning setup we used for Cifar100 to the other three datasets. However, overfitting at an early stage became a consistent issue, and because of that, we had to modify the setup to accommodate each dataset. We encountered a similar problem when we tried to perform supervised learning with mLiT from scratch. Therefore, we report the setups in Table 4 and Table 5 for clarity. It is worth noting that we did not perform a thorough search over the configurations; we simply adjusted the hyperparameters to avoid overfitting during training. Furthermore, unlike in [21], we did not use advanced augmentation techniques in any of our fine-tunings.
During pre-training of mmLiT, we used different numbers of epochs for the different model sizes. From Figure 6, it is clear that, from the perspective of loss, more epochs should have been allocated to mmLiT-XS and mmLiT-XXS. However, one of the downsides of masked autoencoders is the unclear relationship between the loss of a model and its subsequent performance in downstream tasks. In our experience, we have seen cases where further pre-training for a few hundred epochs at a later stage leads to similar or poorer performance in classification. This might be a specific issue when the dataset is small. Despite the above, we believe that with such datasets, further pre-training for thousands of epochs should lead to a better model.
In this study, we found that mLiT pre-trained on only 50,000 images can serve as a sort of foundation model [28], even at the smallest size. On both Cifar10 and Flowers102, models pre-trained on Cifar100 improve by at least 10%. On Cifar10, mmLiT-XXS can reach 90% accuracy in fewer than 40 epochs. On Svhn, the improvements are modest. This is partly because mLiT already performs well without pre-training, and possibly also due to the lack of similar images in Cifar100. We believe a model pre-trained on a slightly larger and more diverse dataset can perform competitively on various simpler tasks at these tiny scales.
In the literature, it is common to integrate convolutional layers [23], [29] or employ a convolutional ‘teacher’ [24] to impart inductive bias. It seems that the use of streamlined MoE can help alleviate the lack of inductive bias. Moreover, as demonstrated in [7], [21], the problem of inductive bias can be overcome with a masked auto-encoder setup.
In this paper, we demonstrated the potential of a streamlined MoE architecture to create very lightweight vision transformers. By sharing parameters among the experts within each MoE layer and adopting depth-wise scaling, we achieved a balance between model complexity and performance. Furthermore, our findings suggest that pre-training on a slightly larger and more diverse dataset could enhance the model's versatility and efficacy across various tasks. The streamlined MoE approach appears promising in mitigating the lack of inductive bias, particularly when used in conjunction with masked autoencoder setups.