Few-shot Learning
(30 min) A literature review of popular methods in few-shot learning, with a focus on metric learning and meta-learning.
1. Introduction
It is an understatement to say that modern Deep Neural Networks (DNNs) require large amounts of data during training. This can be shown both in practice and through learning theory. For example, the classic benchmark dataset ImageNet contains over 14 million images (its popular 1,000-class subset alone has roughly 1.3 million), and language models such as GPT-3 are literally trained on THE INTERNET.
While these DNNs' performance has been impressive, occasionally reaching super-human accuracy, in real life data is rarely as abundant or easy to obtain. The problem is exacerbated when labelled data is required (because £££). Another key challenge in a real-life setting is that model output requirements change over time. For example, imagine working for an identity-document classification company: at launch, your company may only support 4 types of IDs, but over time the number of supported IDs will grow. Re-training the model is an art in itself and can be very time consuming.
When quality data is scarce and model reusability is key, few-shot learning has been employed successfully to discover patterns in data and make useful predictions. Combined with active learning, it can often bring a competitive edge at the inception of an ML company!
2. Problem Definition and Datasets
Few-shot classification is an instantiation of meta-learning in the field of supervised learning (I think the distinction between few-shot learning and meta-learning is rather arbitrary). The best way to understand what few-shot learning is trying to achieve is by examining a few common datasets used to train and evaluate few-shot models. Let's begin by looking at the K-shot N-class classification task:
K-shot N-class classification task: during training, we are provided with several datasets, each containing N classes with K labelled examples per class; at test time, we are given a new dataset and asked to make predictions based on its labelled examples.
The term few-shot (or K-shot) therefore refers to the number of labelled examples provided per class in each task. Pictorially, this is shown as a dataset of datasets:
To be more concrete, the Omniglot dataset contains 50 different alphabets, each with a varying number of characters, for a total of 1,623 characters. Each character was drawn by 20 different people, giving 20 examples per character.
Going back to the earlier definition, recognising the characters of a single alphabet is a 20-shot S-class classification task, where S is the number of characters in that alphabet.
One of the most challenging datasets used for few-shot learning is mini-ImageNet, which contains 100 classes with only 600 examples each. It is worth mentioning that the train-val-test split for few-shot learning is done over the classes; in this case, we typically use a 64-16-20 split.
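To make the episodic setup concrete, here is a minimal sketch of how an N-way K-shot episode might be sampled. The dict-based dataset layout, the function name, and the default sizes are illustrative assumptions, not any particular library's API:

```python
import random

# A minimal sketch of N-way K-shot episode sampling. `dataset` is assumed to
# be a dict mapping class label -> list of examples (illustrative layout).
def sample_episode(dataset, n_way=5, k_shot=1, n_query=15):
    classes = random.sample(sorted(dataset), n_way)          # pick N classes
    support, query = [], []
    for label in classes:
        examples = random.sample(dataset[label], k_shot + n_query)
        support += [(x, label) for x in examples[:k_shot]]   # K shots per class
        query += [(x, label) for x in examples[k_shot:]]
    return support, query
```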
Hopefully I have conveyed how challenging few-shot learning is. Personally, I think few-shot learning is more general than typical supervised learning with DNNs. People normally call this type of learning "learning to learn", since our model needs either weights/architecture that generalise across different tasks, or the ability to rapidly adapt its weights to a new task.
3. Methods/Literature Overview
One rather obvious (though practically difficult) way to achieve few-shot learning is through data augmentation. The motivation is simple: we do not have enough examples per class, so let's just generate more. People can really go crazy with data augmentation; recent methods use ideas such as GANs and saliency-guided hallucination. Since we can always perform data augmentation in conjunction with other approaches (and frankly, the concept is not new), this class of methods will not be discussed in this post.
The other approach, used by most ML practitioners, is transfer learning. The idea of taking an off-the-shelf model trained by a big research group on millions of data points, then fine-tuning the last few layers' weights on our own dataset, is definitely not new, so we will omit these methods too. In fact, for most of the methods demonstrated below, assume we always do some sort of transfer learning via a model like VGG.
Metric-based learning, such as Siamese networks and Matching networks, will be described in depth here. The fundamental idea is that these methods learn a distance metric between the query data point and the support set (the labelled examples).
Optimisation-based learning, such as the LSTM meta-learner and Model-Agnostic Meta-Learning (MAML), modifies classic gradient descent, encouraging the loss function to converge within a small number of optimisation steps.
Model-based learning designs models specifically for fast learning, normally by giving the network explicit memory (much like Neural Turing Machines). These models are interesting, but because they are quite contrived, we will omit their discussion here.
Finally, we will briefly look at Bayesian methods for few-shot learning, including VERSA from Richard Turner's group.
4. Metric-Based Approach
The core idea behind metric-based few-shot approaches can be thought of as performing mixture modelling with a deep network. Let's define a support set $S = \{X, Y\}$, the set of labelled examples available during training and/or inference. At a high level, the metric-based approach produces outputs as weighted sums of the support-set labels:
$$P_{\theta}(y \mid \mathbf{x}, S)=\sum_{\left(\mathbf{x}_{i}, y_{i}\right) \in S} k_{\theta}\left(\mathbf{x}, \mathbf{x}_{i}\right) y_{i}\quad\quad [1]$$
where the weights are generated by some kernel function $k_\theta$.
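As a toy illustration of equation $[1]$, here is a sketch in which the learnt kernel $k_\theta$ is replaced by a fixed Gaussian-style kernel (a softmax over negative squared distances); the function name, shapes, and bandwidth are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

# Equation [1] as code: the prediction is a kernel-weighted sum of one-hot
# support labels. The fixed kernel below stands in for a learnt k_theta.
def metric_predict(x, support_x, support_y, n_classes, bandwidth=1.0):
    # x: (d,), support_x: (n, d), support_y: (n,) integer class labels
    sq_dists = ((support_x - x) ** 2).sum(dim=1)            # (n,)
    weights = torch.softmax(-sq_dists / bandwidth, dim=0)   # k_theta(x, x_i)
    one_hot = F.one_hot(support_y, n_classes).float()       # (n, n_classes)
    return weights @ one_hot                                # P(y | x, S)
```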
4.1 Siamese Network
The most naive form of metric-based few-shot learning is the Siamese network, where we simply use a CNN $f_\theta$ (e.g. a VGG network) to extract features from two images. Then the $l_1$ distance (or any other differentiable distance metric, such as cosine distance) is computed between the two embeddings:
The distance is converted to a pseudo-probability by a linear feedforward layer and a sigmoid (no theoretical grounds here…). To put the animation above in mathematical form (with the $l_1$ distance metric):
$$p\left(x_{i}, x_{j}\right)=\sigma\left(\mathbf{W}\left|f_{\theta}\left(x_{i}\right)-f_{\theta}\left(x_{j}\right)\right|\right)$$
where $\mathbf{W}$ can be trained in conjunction with the CNN via a binary cross-entropy loss (the target being whether the two images are of the same class or not).
To classify a query image, we do a 'brute-force search': we run the query image paired with every class of support image through the network above. The final prediction is simply an argmax over the probabilities that the query image has the same class as each support image:
$$\hat{c}_{S}(x)=c\left(\arg \max _{x_{i} \in S} p\left(x, x_{i}\right)\right)$$
where $c(x)$ is the class label of an image $x$. This is animated below:
Of course, you can shoe-horn the formula above into the format of equation $[1]$ by performing a softmax over all the $p\left(x, x_{i}\right)$ pairs (i.e. re-normalising), but this does not change the performance of the method.
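Here is a minimal sketch of the whole Siamese pipeline, assuming a linear embedding as a stand-in for the CNN $f_\theta$; layer sizes and names are illustrative:

```python
import torch
import torch.nn as nn

# A minimal Siamese sketch: p(x_i, x_j) = sigmoid(W |f(x_i) - f(x_j)|).
# The linear embedding stands in for the CNN f_theta.
class Siamese(nn.Module):
    def __init__(self, in_dim=784, emb_dim=64):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(in_dim, emb_dim), nn.ReLU())  # f_theta
        self.W = nn.Linear(emb_dim, 1)                                 # learned W

    def forward(self, x_i, x_j):
        l1 = (self.f(x_i) - self.f(x_j)).abs()           # component-wise l1
        return torch.sigmoid(self.W(l1)).squeeze(-1)     # pseudo-probability

def classify(model, query, support_x, support_y):
    # 'Brute-force search': pair the query with every support image and take
    # the label of the most similar one.
    probs = model(query.expand_as(support_x), support_x)
    return support_y[probs.argmax()]

# Training uses binary cross-entropy on same-class / different-class pairs:
# loss = nn.functional.binary_cross_entropy(model(x_i, x_j), same.float())
```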
4.2 Relation Network
One obvious downside of the Siamese network is that the choice of distance function is quite arbitrary: why did we use $l_1$ over $l_2$? Another disadvantage is that distances are calculated independently across the support set, which ignores higher-level interactions between the query image and the support set.
The Relation Network overcomes these disadvantages by using another convolutional network as the distance function. Its inputs are the concatenated features of the query image and each image in the support set (one per class). This is animated below:
One note on the loss function: since we are more interested in the degree of relation between the query image and the support set, which is more like a regression, an MSE loss is used instead of cross entropy:
$$\mathcal{L}=\sum_{\left(x_{i}, \mathbf{x}_{j}, y_{i}, y_{j}\right) \in S}\left(g_\theta(z)-1_{y_{i}=y_{j}}\right)^{2}$$
where $z$ is the concatenated output of the feature extractor $f_\phi$.
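A minimal sketch of this idea, with the relation module as a small fully connected network; a real implementation would use convolutional $f_\phi$ and $g_\theta$, and the sizes and names here are illustrative:

```python
import torch
import torch.nn as nn

# A minimal relation-network sketch: the 'distance function' g is itself a
# learned network over concatenated embeddings.
class RelationNet(nn.Module):
    def __init__(self, in_dim=784, emb_dim=64):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(in_dim, emb_dim), nn.ReLU())  # f_phi
        self.g = nn.Sequential(nn.Linear(2 * emb_dim, 64), nn.ReLU(),
                               nn.Linear(64, 1), nn.Sigmoid())         # g_theta

    def forward(self, query, support):
        z = torch.cat([self.f(query), self.f(support)], dim=-1)  # concat features
        return self.g(z).squeeze(-1)                             # relation score

# MSE against the 1/0 'same class' indicator, matching the loss above:
# loss = ((model(x_q, x_s) - (y_q == y_s).float()) ** 2).mean()
```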
4.3 Prototypical Networks
One major disadvantage of both methods above is that we have been using one example image per class. This is $K$-shot learning with $K=1$, i.e. 1-shot learning. We typically want a model whose performance improves as the number of labelled examples grows (larger and larger $K$, and possibly a different $K$ per class). Methods such as Matching networks use a (Bi-)LSTM to 'read' the entirety of the support set.
A simpler, more intuitive method is the idea of prototyping. The idea is similar to manifold learning: a prototype $v_c$ is the average, in some latent space, of all examples of the same class $S_c$:
$$v_{c}=\frac{1}{\left|S_{c}\right|} \sum_{\left(x_{i}, y_{i}\right) \in S_{c}} f_{\theta}\left(x_{i}\right)$$
In the animation below, we see how a prototypical network works, with red squares on the latent map (right) representing prototypes for the three different classes.
At test time, the distribution over classes for a given test input $x$ is a softmax over the negated distances between the test data embedding and the prototype vectors:
$$p(y=c \mid x)=\operatorname{softmax}\left(-d_{\varphi}\left(f_{\theta}(x), v_{c}\right)\right)$$
This is slightly more general than just picking the argmin over distances between the prototypes and the query image's latent embedding, since the distance function $d_{\varphi}$ can be another neural network with learnt weights.
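A minimal sketch of both steps, assuming a plain squared Euclidean distance in place of a learned $d_{\varphi}$; names and shapes are illustrative:

```python
import torch

# A minimal prototypical-networks sketch. Embeddings are assumed to come
# from some feature extractor f_theta applied beforehand.
def prototypes(support_emb, support_y, n_classes):
    # support_emb: (n, d), support_y: (n,) integer labels
    return torch.stack([support_emb[support_y == c].mean(dim=0)
                        for c in range(n_classes)])              # v_c, (C, d)

def predict(query_emb, protos):
    # query_emb: (m, d), protos: (C, d); softmax over negated distances
    sq_dists = ((query_emb.unsqueeze(1) - protos.unsqueeze(0)) ** 2).sum(-1)
    return torch.softmax(-sq_dists, dim=-1)                      # p(y = c | x)
```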
5. Optimisation-Based Approach
Another way to tackle few-shot learning is by modifying how parameters are updated within DL models. This approach is termed optimisation-based. The key concept is that these methods utilise another network, i.e. a meta network, that learns how to train a deep learning model across different tasks.
5.1 LSTM Meta-Learner
The standard optimisation algorithms used to train deep neural networks take the form:
$$\theta_{t}=\theta_{t-1}-\alpha_{t} \nabla_{\theta_{t-1}} \mathcal{L}_{t}$$
where $\theta_{t-1}$ are the parameters of the model after $t-1$ updates, $\alpha_t$ is the learning rate at time $t$, $\mathcal{L}_{t}$ is the loss optimised by the model, and $\nabla_{\theta_{t-1}} \mathcal{L}_{t}$ denotes the gradient of the loss. Ravi & Larochelle (2017) observed that this update resembles the update for the cell state in an LSTM:
$$c_{t}=f_{t} \odot c_{t-1}+i_{t} \odot \tilde{c}_{t}$$
if $f_{t}=1$, $c_{t-1}=\theta_{t-1}$, $i_{t}=\alpha_{t}$ and $\tilde{c}_{t}=-\nabla_{\theta_{t-1}} \mathcal{L}_{t}$. Thus, they proposed training a meta-learner LSTM to learn an update rule for training a learner network. The cell state of the LSTM holds the parameters of the learner, i.e. $c_{t}=\theta_{t}$, and the candidate cell state is $\tilde{c}_{t}=-\nabla_{\theta_{t-1}} \mathcal{L}_{t}$.
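A quick numeric sanity check of this correspondence (all values below are made up for illustration):

```python
import torch

# With forget gate f_t = 1, input gate i_t = alpha_t and candidate cell state
# c~_t = -grad, the LSTM cell update c_t = f_t * c_{t-1} + i_t * c~_t is
# exactly one step of gradient descent.
theta_prev = torch.tensor([0.5, -1.2])     # c_{t-1}: current parameters
grad = torch.tensor([0.1, -0.3])           # pretend gradient of the loss
alpha = 0.01                               # learning rate

f_t = torch.ones_like(theta_prev)          # forget gate fixed to 1
i_t = alpha * torch.ones_like(theta_prev)  # input gate plays the learning rate
c_tilde = -grad                            # candidate cell state

theta_next = f_t * theta_prev + i_t * c_tilde                 # LSTM cell update
assert torch.allclose(theta_next, theta_prev - alpha * grad)  # == SGD step
```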
The best way to understand the LSTM meta-learner is by first looking at the animation below and then the algorithm shown in the original paper:
The $\mathcal{D}_{meta-train}$ dataset is the dataset of datasets we discussed in the second section. Basically, we sample a dataset (1-shot, 5-class in the animation above) and use it to update the learner's parameters via the meta-learner (whose parameters are fixed). We then use the test set of the sampled dataset to update the meta-learner's parameters. It may seem odd to train on a test set, but remember that in this setting our true test set is another dataset of datasets. We repeat the procedure above by sampling another dataset from $\mathcal{D}_{meta-train}$.
One key challenge for this method is that the learner has tens of thousands of parameters, all of which the meta-learner needs to output. To prevent an explosion in the number of meta-learner parameters, the authors employed a form of parameter sharing: the same meta-learner weights are applied coordinate-wise across the learner's parameters. The original paper has a more in-depth description of how this is achieved.
5.2 MAML
MAML, which stands for Model-Agnostic Meta-Learning (Finn et al., 2017), is a fairly general optimisation algorithm that can be used not only for few-shot classification but also generalises to reinforcement learning. It is also quite easy to implement, and achieves (or used to achieve) SOTA on a variety of tasks.
The high-level idea behind MAML is to update the parameters so that we always move towards $\theta^*$, the parameters that are "closest" to the optimal parameters for all the subtasks, making fine-tuning more efficient. Pictorially, this is shown in the animation below, where $\theta_i^*$ are the optimal parameters for the $i$-th task.
The algorithm for MAML is very simple; the version below is taken from the original paper. Unfortunately, a different set of notation is used there, but $p(\mathcal{T})$ literally means the meta-training dataset (the dataset of datasets), as in the meta-LSTM.
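To close this section, here is a toy sketch of the MAML inner/outer loop on 1-D linear regression tasks; the task distribution, step sizes, and loop counts are illustrative assumptions, not values from the paper:

```python
import torch

# Toy second-order MAML on tasks of the form y = a * x.
theta = torch.zeros(1, requires_grad=True)      # meta-parameters
meta_opt = torch.optim.SGD([theta], lr=1e-2)    # outer-loop optimiser
inner_lr = 0.1                                  # inner-loop step size

for step in range(1000):
    meta_opt.zero_grad()
    for _ in range(4):                          # batch of sampled tasks
        a = torch.randn(1)                      # sample a task: y = a * x
        x_s, x_q = torch.randn(5, 1), torch.randn(5, 1)
        y_s, y_q = a * x_s, a * x_q
        # Inner step: adapt theta on the support set, keeping the graph so
        # the outer gradient can flow through the inner update.
        support_loss = ((x_s * theta - y_s) ** 2).mean()
        (grad,) = torch.autograd.grad(support_loss, theta, create_graph=True)
        theta_adapted = theta - inner_lr * grad
        # Outer objective: query-set loss *after* adaptation.
        query_loss = ((x_q * theta_adapted - y_q) ** 2).mean()
        query_loss.backward()                   # accumulate d(query)/d(theta)
    meta_opt.step()                             # move theta towards theta*
```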