Meta-learning, also known as “learning to learn”, aims to design models that can learn new skills or adapt to new environments rapidly with only a few training examples. There are three common approaches: 1) learn an efficient distance metric (metric-based); 2) use a (recurrent) network with external or internal memory (model-based); 3) optimize the model parameters explicitly for fast learning (optimization-based).

[Updated on 2019-10-01: thanks to Tianhao, we have this post translated into Chinese!]

A good machine learning model often requires training with a large number of samples. Humans, in contrast, learn new concepts and skills much faster and more efficiently. Kids who have seen cats and birds only a few times can quickly tell them apart. People who know how to ride a bike are likely to figure out how to ride a motorcycle quickly, with little or even no demonstration. Is it possible to design a machine learning model with similar properties — learning new concepts and skills fast with a few training examples? That’s essentially what meta-learning aims to solve.

We expect a good meta-learning model to be capable of adapting or generalizing well to new tasks and new environments that have never been encountered during training. The adaptation process, essentially a mini learning session, happens at test time but with limited exposure to the new task configuration. Eventually, the adapted model can complete the new task. This is why meta-learning is also known as learning to learn.

The tasks can be any well-defined family of machine learning problems: supervised learning, reinforcement learning, etc. For example, here are a couple of concrete meta-learning tasks:

  • A classifier trained on non-cat images can tell whether a given image contains a cat after seeing a handful of cat pictures.
  • A game bot is able to quickly master a new game.
  • A mini robot completes the desired task on an uphill surface during test even though it was only trained in a flat-surface environment.

Define the Meta-Learning Problem

In this post, we focus on the case when each desired task is a supervised learning problem like image classification. There is a lot of interesting literature on meta-learning with reinforcement learning problems (aka “Meta Reinforcement Learning”), but we will not cover it here.

A Simple View

A good meta-learning model should be trained over a variety of learning tasks and optimized for the best performance on a distribution of tasks, including potentially unseen tasks. Each task is associated with a dataset D, containing both feature vectors and true labels. The optimal model parameters are:

$$\theta^* = \arg\min_\theta \mathbb{E}_{D\sim p(\mathcal{D})} [\mathcal{L}_\theta(D)]$$

It looks very similar to a normal learning task, but one dataset is considered as one data sample.

Few-shot classification is an instantiation of meta-learning in the field of supervised learning. The dataset $D$ is often split into two parts, a support set $S$ for learning and a prediction set $B$ for training or testing, $D=\langle S, B\rangle$. Often we consider a $K$-shot $N$-class classification task: the support set contains $K$ labelled examples for each of $N$ classes.


Fig. 1. An example of 4-shot 2-class image classification. (Image thumbnails are from Pinterest)

Training in the Same Way as Testing

A dataset $D$ contains pairs of feature vectors and labels, $D = \{(\mathbf{x}_i, y_i)\}$, and each label belongs to a known label set $\mathcal{L}^\text{label}$. Let’s say our classifier $f_\theta$ with parameter $\theta$ outputs the probability of a data point belonging to class $y$ given the feature vector $\mathbf{x}$, $P_\theta(y\vert\mathbf{x})$.

The optimal parameters should maximize the probability of true labels across multiple training batches $B \subset D$:

$$\begin{aligned}
\theta^* &= \arg\max_{\theta} \mathbb{E}_{(\mathbf{x}, y)\in D}[P_\theta(y \vert \mathbf{x})] \\
\theta^* &= \arg\max_{\theta} \mathbb{E}_{B\subset D}\Big[\sum_{(\mathbf{x}, y)\in B} P_\theta(y \vert \mathbf{x})\Big] & \text{; trained with mini-batches.}
\end{aligned}$$

In few-shot classification, the goal is to reduce the prediction error on data samples with unknown labels given a small support set for “fast learning” (think of how “fine-tuning” works). To make the training process mimic what happens during inference, we would like to “fake” datasets with a subset of labels to avoid exposing all the labels to the model, and modify the optimization procedure accordingly to encourage fast learning:

  1. Sample a subset of labels, $L\subset\mathcal{L}^\text{label}$.
  2. Sample a support set $S^L \subset D$ and a training batch $B^L \subset D$. Both of them only contain data points with labels belonging to the sampled label set $L$: $y \in L$ for all $(\mathbf{x}, y) \in S^L \cup B^L$.
  3. The support set is part of the model input.
  4. The final optimization uses the mini-batch $B^L$ to compute the loss and update the model parameters through backpropagation, in the same way as in supervised learning.

You may consider each pair of sampled datasets $(S^L, B^L)$ as one data point. The model is trained such that it can generalize to other datasets. The dependency on the support set $S^L$ in the objective below is what meta-learning adds on top of the supervised learning objective.

$$\theta = \arg\max_\theta \mathbb{E}_{L\subset\mathcal{L}}\Big[ \mathbb{E}_{S^L \subset D, B^L \subset D} \Big[ \sum_{(\mathbf{x}, y)\in B^L} P_\theta(y \vert \mathbf{x}, S^L) \Big] \Big]$$
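
To make this episodic setup concrete, below is a minimal Python sketch of how one training episode could be sampled. It only illustrates the sampling steps above; the helper name `sample_episode` and the toy dataset are made up for the example.

```python
import random
import numpy as np

def sample_episode(data_by_label, n_way=5, k_shot=1, n_query=15):
    """Build one N-way K-shot episode from a dict {label: [feature vectors]}."""
    # 1. Sample a subset of labels L from the full label set.
    labels = random.sample(list(data_by_label.keys()), n_way)
    support, query = [], []
    for new_label, label in enumerate(labels):
        # 2. Sample disjoint support and query (mini-batch B^L) examples per label.
        examples = random.sample(data_by_label[label], k_shot + n_query)
        support += [(x, new_label) for x in examples[:k_shot]]   # 3. fed to the model as context
        query += [(x, new_label) for x in examples[k_shot:]]     # 4. used to compute the loss
    return support, query

# Usage sketch: a toy dataset of 20 classes with random 64-d feature vectors.
data_by_label = {c: [np.random.randn(64) for _ in range(30)] for c in range(20)}
support, query = sample_episode(data_by_label, n_way=5, k_shot=1)
```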

The idea is to some extent similar to using a pre-trained model in image classification (ImageNet) or language modeling (big text corpora) when only a limited set of task-specific data samples are available. Meta-learning takes this idea one step further: rather than fine-tuning on one downstream task, it optimizes the model to be good at many tasks, if not all of them.

Learner and Meta-Learner

Another popular view of meta-learning decomposes the model update into two stages:

  • A classifier $f_\theta$ is the “learner” model, trained to carry out a given task;
  • In the meantime, an optimizer $g_\phi$ learns how to update the learner model’s parameters via the support set $S$, $\theta' = g_\phi(\theta, S)$.

Then, in the final optimization step, we need to update both $\theta$ and $\phi$ to maximize:

$$\mathbb{E}_{L\subset\mathcal{L}}\Big[ \mathbb{E}_{S^L \subset D, B^L \subset D} \Big[ \sum_{(\mathbf{x}, y)\in B^L} P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x}) \Big] \Big]$$

Common Approaches

There are three common approaches to meta-learning: metric-based, model-based, and optimization-based. Oriol Vinyals has a nice summary in his talk at the meta-learning symposium @ NIPS 2018:

| | Model-based | Metric-based | Optimization-based |
| --- | --- | --- | --- |
| **Key idea** | RNN; memory | Metric learning | Gradient descent |
| **How $P_\theta(y \vert \mathbf{x})$ is modeled?** | $f_\theta(\mathbf{x}, S)$ | $\sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i) y_i$ (*) | $P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x})$ |

(*) $k_\theta$ is a kernel function measuring the similarity between $\mathbf{x}_i$ and $\mathbf{x}$.

Next we are gonna review classic models in each approach.

Metric-Based

The core idea in metric-based meta-learning is similar to nearest neighbors algorithms (e.g., the k-NN classifier and k-means clustering) and kernel density estimation. The predicted probability over a set of known labels $y$ is a weighted sum of the labels of support set samples. The weight is generated by a kernel function $k_\theta$, measuring the similarity between two data samples.

$$P_\theta(y \vert \mathbf{x}, S) = \sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i) y_i$$

To learn a good kernel is crucial to the success of a metric-based meta-learning model. Metric learning is well aligned with this intention, as it aims to learn a metric or distance function over objects. The notion of a good metric is problem-dependent. It should represent the relationship between inputs in the task space and facilitate problem solving.

All the models introduced below learn embedding vectors of input data explicitly and use them to design proper kernel functions.

Convolutional Siamese Neural Network

The Siamese Neural Network is composed of two twin networks and their outputs are jointly trained on top with a function to learn the relationship between pairs of input data samples. The twin networks are identical, sharing the same weights and network parameters. In other words, both refer to the same embedding network that learns an efficient embedding to reveal relationship between pairs of data points.

Koch, Zemel & Salakhutdinov (2015) proposed a method to use the siamese neural network to do one-shot image classification. First, the siamese network is trained for a verification task for telling whether two input images are in the same class. It outputs the probability of two images belonging to the same class. Then, during test time, the siamese network processes all the image pairs between a test image and every image in the support set. The final prediction is the class of the support image with the highest probability.


Fig. 2. The architecture of a convolutional siamese neural network for few-shot image classification.

  1. First, the convolutional siamese network learns to encode two images into feature vectors via an embedding function $f_\theta$ which contains a couple of convolutional layers.
  2. The L1-distance between the two embeddings is $\vert f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \vert$.
  3. The distance is converted to a probability p by a linear feedforward layer and sigmoid. It is the probability of whether two images are drawn from the same class.
  4. Intuitively the loss is cross entropy because the label is binary.
$$\begin{aligned}
p(\mathbf{x}_i, \mathbf{x}_j) &= \sigma(\mathbf{W}\vert f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \vert) \\
\mathcal{L}(B) &= -\sum_{(\mathbf{x}_i, \mathbf{x}_j, y_i, y_j)\in B} \big[\mathbf{1}_{y_i=y_j}\log p(\mathbf{x}_i, \mathbf{x}_j) + (1-\mathbf{1}_{y_i=y_j})\log (1-p(\mathbf{x}_i, \mathbf{x}_j))\big]
\end{aligned}$$

Images in the training batch $B$ can be augmented with distortion. Of course, you can replace the L1 distance with another distance metric, L2, cosine, etc. Just make sure it is differentiable and then everything else works the same.

Given a support set S and a test image x, the final predicted class is:

$$\hat{c}_S(\mathbf{x}) = c(\arg\max_{\mathbf{x}_i \in S} P(\mathbf{x}, \mathbf{x}_i))$$

where $c(\mathbf{x})$ is the class label of an image $\mathbf{x}$ and $\hat{c}(\cdot)$ is the predicted label.

The assumption is that the learned embedding can be generalized to be useful for measuring the distance between images of unknown categories. This is the same assumption behind transfer learning via the adoption of a pre-trained model; for example, the convolutional features learned in the model pre-trained with ImageNet are expected to help other image tasks. However, the benefit of a pre-trained model decreases when the new task diverges from the original task that the model was trained on.
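
Below is a rough PyTorch sketch of the verification-then-nearest-support idea described above. The encoder layout, layer sizes, and tensors are placeholders for illustration rather than the exact architecture of Koch et al.

```python
import torch
import torch.nn as nn

class SiameseOneShot(nn.Module):
    """Sketch of the verification network: p = sigmoid(W |f(x_i) - f(x_j)|)."""
    def __init__(self, emb_dim=64):
        super().__init__()
        self.embed = nn.Sequential(                      # f_theta: a tiny conv encoder
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4), nn.Flatten(),
            nn.Linear(32 * 16, emb_dim),
        )
        self.out = nn.Linear(emb_dim, 1)                 # the linear layer W above

    def forward(self, x_i, x_j):
        # Probability that the two images belong to the same class.
        dist = torch.abs(self.embed(x_i) - self.embed(x_j))
        return torch.sigmoid(self.out(dist)).squeeze(-1)

model = SiameseOneShot()
# Verification training: binary cross-entropy on same/different image pairs.
x_i, x_j = torch.randn(8, 1, 28, 28), torch.randn(8, 1, 28, 28)
same = torch.randint(0, 2, (8,)).float()
loss = nn.functional.binary_cross_entropy(model(x_i, x_j), same)
# One-shot prediction: the class of the support image with the highest probability.
support_x = torch.randn(5, 1, 28, 28)                    # one image per class
query = torch.randn(1, 1, 28, 28)
predicted_class = model(query.expand(5, -1, -1, -1), support_x).argmax().item()
```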

Matching Networks

The task of Matching Networks (Vinyals et al., 2016) is to learn a classifier $c_S$ for any given (small) support set $S=\{(\mathbf{x}_i, y_i)\}_{i=1}^k$ (k-shot classification). This classifier defines a probability distribution over output labels $y$ given a test example $\mathbf{x}$. Similar to other metric-based models, the classifier output is defined as a sum of labels of support samples weighted by an attention kernel $a(\mathbf{x}, \mathbf{x}_i)$, which should be proportional to the similarity between $\mathbf{x}$ and $\mathbf{x}_i$.


Fig. 3. The architecture of Matching Networks. (Image source: original paper)

$$c_S(\mathbf{x}) = P(y \vert \mathbf{x}, S) = \sum_{i=1}^k a(\mathbf{x}, \mathbf{x}_i) y_i \text{, where } S=\{(\mathbf{x}_i, y_i)\}_{i=1}^k$$

The attention kernel depends on two embedding functions, f and g, for encoding the test sample and the support set samples respectively. The attention weight between two data points is the cosine similarity, cosine(.), between their embedding vectors, normalized by softmax:

$$a(\mathbf{x}, \mathbf{x}_i) = \frac{\exp(\text{cosine}(f(\mathbf{x}), g(\mathbf{x}_i)))}{\sum_{j=1}^k \exp(\text{cosine}(f(\mathbf{x}), g(\mathbf{x}_j)))}$$
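
The simple (non-FCE) version of this attention kernel is easy to write down. Here is a small PyTorch sketch, assuming the embeddings $f(\mathbf{x})$ and $g(\mathbf{x}_i)$ have already been computed by some encoder; names and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def matching_net_predict(f_x, g_support, support_labels, n_classes):
    """P(y|x, S): softmax-normalized cosine attention over support embeddings,
    applied to the one-hot support labels."""
    sims = F.cosine_similarity(f_x.unsqueeze(0), g_support, dim=-1)  # cosine(f(x), g(x_i)), [k]
    attn = F.softmax(sims, dim=0)                                    # a(x, x_i)
    one_hot = F.one_hot(support_labels, n_classes).float()           # [k, n_classes]
    return attn @ one_hot                                            # weighted sum of labels

# Usage: a 2-way 4-shot episode with 32-d embeddings (random placeholders).
f_x, g_support = torch.randn(32), torch.randn(8, 32)
support_labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
print(matching_net_predict(f_x, g_support, support_labels, n_classes=2))
```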

Simple Embedding

In the simple version, an embedding function is a neural network with a single data sample as input. Potentially we can set f=g.

Full Context Embeddings

The embedding vectors are critical inputs for building a good classifier. Taking a single data point as input might not be enough to efficiently gauge the entire feature space. Therefore, the Matching Network model further proposed to enhance the embedding functions by taking as input the whole support set S in addition to the original input, so that the learned embedding can be adjusted based on the relationship with other support samples.

  • $g_\theta(\mathbf{x}_i, S)$ uses a bidirectional LSTM to encode $\mathbf{x}_i$ in the context of the entire support set $S$.
  • $f_\theta(\mathbf{x}, S)$ encodes the test sample $\mathbf{x}$ via an LSTM with read attention over the support set $S$.
    1. First the test sample goes through a simple neural network, such as a CNN, to extract basic features, $f'(\mathbf{x})$.
    2. Then an LSTM is trained with a read attention vector over the support set as part of the hidden state:
      $$\begin{aligned}
      \hat{h}_t, c_t &= \text{LSTM}(f'(\mathbf{x}), [h_{t-1}, r_{t-1}], c_{t-1}) \\
      h_t &= \hat{h}_t + f'(\mathbf{x}) \\
      r_{t-1} &= \sum_{i=1}^k a(h_{t-1}, g(\mathbf{x}_i)) g(\mathbf{x}_i) \\
      a(h_{t-1}, g(\mathbf{x}_i)) &= \text{softmax}(h_{t-1}^\top g(\mathbf{x}_i)) = \frac{\exp(h_{t-1}^\top g(\mathbf{x}_i))}{\sum_{j=1}^k \exp(h_{t-1}^\top g(\mathbf{x}_j))}
      \end{aligned}$$
    3. Eventually $f(\mathbf{x}, S) = h_K$ if we do $K$ steps of “read”.

This embedding method is called “Full Contextual Embeddings (FCE)”. Interestingly it does help improve the performance on a hard task (few-shot classification on mini ImageNet), but makes no difference on a simple task (Omniglot).

The training process in Matching Networks is designed to match inference at test time; see the details in the earlier section. It is worth mentioning that the Matching Networks paper refined the idea that training and testing conditions should match.

$$\theta^* = \arg\max_\theta \mathbb{E}_{L\subset\mathcal{L}}\Big[ \mathbb{E}_{S^L \subset D, B^L \subset D} \Big[ \sum_{(\mathbf{x}, y)\in B^L} P_\theta(y \vert \mathbf{x}, S^L) \Big] \Big]$$

Relation Network

Relation Network (RN) (Sung et al., 2018) is similar to a siamese network but with a few differences:

  1. The relationship is not captured by a simple L1 distance in the feature space, but predicted by a CNN classifier $g_\phi$. The relation score between a pair of inputs, $\mathbf{x}_i$ and $\mathbf{x}_j$, is $r_{ij} = g_\phi([\mathbf{x}_i, \mathbf{x}_j])$ where $[\cdot, \cdot]$ is concatenation.
  2. The objective function is MSE loss instead of cross-entropy, because conceptually RN focuses more on predicting relation scores, which is more like regression than binary classification: $\mathcal{L}(B) = \sum_{(\mathbf{x}_i, \mathbf{x}_j, y_i, y_j)\in B} (r_{ij} - \mathbf{1}_{y_i=y_j})^2$.


Fig. 4. Relation Network architecture for a 5-way 1-shot problem with one query example. (Image source: original paper)

(Note: There is another Relation Network for relational reasoning, proposed by DeepMind. Don’t get confused.)
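
As a quick illustration of the two differences above, here is a toy PyTorch sketch. For brevity, the relation module $g_\phi$ is written as an MLP over pre-computed embedding vectors instead of the CNN over concatenated feature maps used in the paper; sizes are made up.

```python
import torch
import torch.nn as nn

class RelationModule(nn.Module):
    """Relation score r_ij = g_phi([x_i, x_j]); here g_phi is a small MLP on embeddings."""
    def __init__(self, emb_dim=64):
        super().__init__()
        self.g = nn.Sequential(
            nn.Linear(2 * emb_dim, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Sigmoid(),              # relation score in [0, 1]
        )

    def forward(self, emb_i, emb_j):
        return self.g(torch.cat([emb_i, emb_j], dim=-1)).squeeze(-1)

# MSE objective: the target relation score is 1 when the pair shares a class, else 0.
rel = RelationModule()
emb_i, emb_j = torch.randn(16, 64), torch.randn(16, 64)
same = torch.randint(0, 2, (16,)).float()
loss = nn.functional.mse_loss(rel(emb_i, emb_j), same)
```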

Prototypical Networks

Prototypical Networks (Snell, Swersky & Zemel, 2017) use an embedding function $f_\theta$ to encode each input into an $M$-dimensional feature vector. A prototype feature vector is defined for every class $c \in \mathcal{C}$ as the mean vector of the embedded support data samples in this class.

$$\mathbf{v}_c = \frac{1}{|S_c|} \sum_{(\mathbf{x}_i, y_i) \in S_c} f_\theta(\mathbf{x}_i)$$


Fig. 5. Prototypical networks in the few-shot and zero-shot scenarios. (Image source: original paper)

The distribution over classes for a given test input $\mathbf{x}$ is a softmax over the negative distances between the test data embedding and the prototype vectors.

$$P(y=c \vert \mathbf{x}) = \text{softmax}(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_c)) = \frac{\exp(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_c))}{\sum_{c' \in \mathcal{C}} \exp(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_{c'}))}$$

where $d_\varphi$ can be any distance function as long as $\varphi$ is differentiable. In the paper, they used the squared Euclidean distance.

The loss function is the negative log-likelihood: $\mathcal{L}(\theta) = -\log P_\theta(y=c \vert \mathbf{x})$.
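
The whole model fits into a few lines. Below is a minimal PyTorch sketch of the prototype computation, the softmax over negative squared Euclidean distances, and the negative log-likelihood loss; the toy episode and shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

def prototypical_loss(support_emb, support_labels, query_emb, query_labels, n_classes):
    """Class prototypes are support means; prediction is a softmax over negative
    squared Euclidean distances; the loss is the negative log-likelihood."""
    prototypes = torch.stack([
        support_emb[support_labels == c].mean(dim=0) for c in range(n_classes)
    ])                                                  # v_c, [n_classes, M]
    dists = torch.cdist(query_emb, prototypes) ** 2     # squared Euclidean distances
    log_p = F.log_softmax(-dists, dim=-1)
    return F.nll_loss(log_p, query_labels)

# Usage: a 3-way episode, 2 support and 4 query points per class, M = 16.
sup, sup_y = torch.randn(6, 16), torch.tensor([0, 0, 1, 1, 2, 2])
qry, qry_y = torch.randn(12, 16), torch.arange(3).repeat_interleave(4)
print(prototypical_loss(sup, sup_y, qry, qry_y, n_classes=3))
```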

Model-Based

Model-based meta-learning models make no assumption on the form of $P_\theta(y\vert\mathbf{x})$. Rather, they depend on a model designed specifically for fast learning — a model that updates its parameters rapidly with a few training steps. This rapid parameter update can be achieved by its internal architecture or controlled by another meta-learner model.

Memory-Augmented Neural Networks

A family of model architectures use external memory storage to facilitate the learning process of neural networks, including Neural Turing Machines and Memory Networks. With an explicit storage buffer, it is easier for the network to rapidly incorporate new information and not to forget in the future. Such a model is known as MANN, short for “Memory-Augmented Neural Network”. Note that recurrent neural networks with only internal memory such as vanilla RNN or LSTM are not MANNs.

Because MANN is expected to encode new information fast and thus adapt to new tasks after only a few samples, it fits well for meta-learning. Taking the Neural Turing Machine (NTM) as the base model, Santoro et al. (2016) proposed a set of modifications on the training setup and the memory retrieval mechanisms (or “addressing mechanisms”, deciding how to assign attention weights to memory vectors). Please go through the NTM section in my other post first if you are not familiar with it before reading on.

As a quick recap, the NTM couples a controller neural network with external memory storage. The controller learns to read and write memory rows by soft attention, while the memory serves as a knowledge repository. The attention weights are generated by its addressing mechanism: content-based + location-based.


Fig. 6. The architecture of a Neural Turing Machine (NTM). The memory at time $t$, $\mathbf{M}_t$, is a matrix of size $N \times M$, containing $N$ vector rows, each of $M$ dimensions.

MANN for Meta-Learning

To use MANN for meta-learning tasks, we need to train it in a way that the memory can encode and capture information of new tasks fast and, in the meantime, any stored representation is easily and stably accessible.

The training described in Santoro et al., 2016 happens in an interesting way so that the memory is forced to hold information for longer until the appropriate labels are presented later. In each training episode, the true label $y_t$ is presented with a one-step offset, $(\mathbf{x}_{t+1}, y_t)$: it is the true label for the input at the previous time step $t$, but presented as part of the input at time step $t+1$.


Fig. 7. Task setup in MANN for meta-learning (Image source: original paper).

In this way, MANN is motivated to memorize the information of a new dataset, because the memory has to hold the current input until the label is presented later, and then retrieve the old information to make a prediction accordingly.

Next let us see how the memory is updated for efficient information retrieval and storage.

Addressing Mechanism for Meta-Learning

Aside from the training process, a new pure content-based addressing mechanism is utilized to make the model better suited for meta-learning.

» How to read from memory?
The read attention is constructed purely based on the content similarity.

First, a key feature vector $\mathbf{k}_t$ is produced at time step $t$ by the controller as a function of the input $\mathbf{x}$. Similar to NTM, a read weighting vector $\mathbf{w}_t^r$ of $N$ elements is computed as the cosine similarity between the key vector and every memory vector row, normalized by softmax. The read vector $\mathbf{r}_t$ is a sum of memory records weighted by these weights:

$$\mathbf{r}_t = \sum_{i=1}^N w_t^r(i)\mathbf{M}_t(i) \text{, where } w_t^r(i) = \text{softmax}\Big(\frac{\mathbf{k}_t \cdot \mathbf{M}_t(i)}{\|\mathbf{k}_t\| \cdot \|\mathbf{M}_t(i)\|}\Big)$$

where $\mathbf{M}_t$ is the memory matrix at time $t$ and $\mathbf{M}_t(i)$ is the $i$-th row in this matrix.
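
A content-based read of this kind can be sketched in a few lines of NumPy; the sizes are arbitrary and the controller that produces the key is omitted.

```python
import numpy as np

def content_based_read(key, memory):
    """Cosine similarity between the key k_t and every memory row M_t(i),
    normalized by softmax, then a weighted sum of the rows (the read vector r_t)."""
    sims = memory @ key / (np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-8)
    w_read = np.exp(sims) / np.exp(sims).sum()          # read weights w_t^r
    return w_read @ memory, w_read

memory = np.random.randn(128, 40)                       # N = 128 rows, M = 40 dims
key = np.random.randn(40)                               # k_t produced by the controller
r_t, w_read = content_based_read(key, memory)
```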

» How to write into memory?
The addressing mechanism for writing newly received information into memory operates a lot like a cache replacement policy. The Least Recently Used Access (LRUA) writer is designed for MANN to work better in the scenario of meta-learning. An LRUA write head prefers to write new content to either the least used memory location or the most recently used memory location.

  • Rarely used locations: so that we can preserve frequently used information (see LFU);
  • The last used location: the motivation is that once a piece of information is retrieved, it probably won’t be called again for a while (see MRU).

There are many cache replacement algorithms and each of them could potentially replace the design here with better performance in different use cases. Furthermore, it would be a good idea to learn the memory usage pattern and addressing strategies rather than set them arbitrarily.

The preference of LRUA is carried out in a way that everything is differentiable:

  1. The usage weight $\mathbf{w}_t^u$ at time $t$ is a sum of the current read and write vectors, in addition to the decayed last usage weight, $\gamma \mathbf{w}_{t-1}^u$, where $\gamma$ is a decay factor.
  2. The write vector is an interpolation between the previous read weight (prefer “the last used location”) and the previous least-used weight (prefer “rarely used location”). The interpolation parameter is the sigmoid of a hyperparameter $\alpha$.
  3. The least-used weight $\mathbf{w}^{lu}$ is scaled according to the usage weights $\mathbf{w}_t^u$: a dimension is set to 1 if its usage weight is no greater than the $n$-th smallest element in the vector, and 0 otherwise.

$$\begin{aligned}
\mathbf{w}_t^u &= \gamma \mathbf{w}_{t-1}^u + \mathbf{w}_t^r + \mathbf{w}_t^w \\
\mathbf{w}_t^r &= \text{softmax}(\text{cosine}(\mathbf{k}_t, \mathbf{M}_t(i))) \\
\mathbf{w}_t^w &= \sigma(\alpha)\mathbf{w}_{t-1}^r + (1-\sigma(\alpha))\mathbf{w}_{t-1}^{lu} \\
\mathbf{w}_t^{lu} &= \mathbf{1}_{w_t^u(i) \leq m(\mathbf{w}_t^u, n)} \text{, where } m(\mathbf{w}_t^u, n) \text{ is the $n$-th smallest element in vector } \mathbf{w}_t^u.
\end{aligned}$$

Finally, after the least used memory location, indicated by $\mathbf{w}_t^{lu}$, is set to zero, every memory row is updated:

$$\mathbf{M}_t(i) = \mathbf{M}_{t-1}(i) + w_t^w(i)\mathbf{k}_t, \forall i$$
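
Here is a rough NumPy sketch of one LRUA write step following the equations above; $\alpha$, $\gamma$ and $n$ are treated as given scalars and everything else (the controller, the current read weights) is assumed to be computed elsewhere.

```python
import numpy as np

def lrua_write(memory, key, w_read, w_read_prev, w_lu_prev, w_usage_prev,
               alpha=1.0, gamma=0.95, n=1):
    """One LRUA write step; all weight vectors are over the N memory slots."""
    memory = memory.copy()
    sigma_alpha = 1.0 / (1.0 + np.exp(-alpha))
    # Write weights interpolate between the previous read and least-used weights.
    w_write = sigma_alpha * w_read_prev + (1.0 - sigma_alpha) * w_lu_prev
    # The least-used slot(s) from the previous step are erased before writing.
    memory[w_lu_prev > 0] = 0.0
    # Every row is updated: M_t(i) = M_{t-1}(i) + w_t^w(i) * k_t.
    memory += np.outer(w_write, key)
    # Usage weights decay and accumulate the current read and write weights.
    w_usage = gamma * w_usage_prev + w_read + w_write
    # Least-used indicator: 1 for rows whose usage is among the n smallest, else 0.
    w_lu = (w_usage <= np.sort(w_usage)[n - 1]).astype(float)
    return memory, w_write, w_usage, w_lu
```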

Meta Networks

Meta Networks (Munkhdalai & Yu, 2017), short for MetaNet, is a meta-learning model with architecture and training process designed for rapid generalization across tasks.

Fast Weights

The rapid generalization of MetaNet relies on “fast weights”. There are a handful of papers on this topic, but I haven’t read all of them in detail and I failed to find a very concrete definition, only a vague agreement on the concept. Normally weights in neural networks are updated by stochastic gradient descent on an objective function and this process is known to be slow. One faster way to learn is to utilize one neural network to predict the parameters of another neural network; the generated weights are called fast weights. In comparison, the ordinary SGD-based weights are named slow weights.

In MetaNet, loss gradients are used as meta information to populate models that learn fast weights. Slow and fast weights are combined to make predictions in neural networks.

slow-fast-weights

Fig. 8. Combining slow and fast weights in an MLP. $\oplus$ is element-wise sum. (Image source: original paper)
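
A minimal sketch of what Fig. 8 depicts for a single layer, with made-up shapes: the same input is pushed through the slow weights and the fast weights, and the two pre-activations are aggregated by an element-wise sum before the nonlinearity.

```python
import torch
import torch.nn.functional as F

def layer_with_fast_weights(x, W_slow, W_fast):
    """Pre-activations from both weight matrices are summed element-wise (the sum in Fig. 8)."""
    return F.relu(x @ W_slow.T + x @ W_fast.T)

x = torch.randn(4, 32)
W_slow = torch.randn(64, 32)        # ordinary SGD-trained (slow) weights
W_fast = torch.randn(64, 32)        # predicted on the fly from meta information
h = layer_with_fast_weights(x, W_slow, W_fast)
```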

Model Components

Disclaimer: Below you will find my annotations are different from those in the paper. imo, the paper is poorly written, but the idea is still interesting. So I’m presenting the idea in my own language.

Key components of MetaNet are:

  • An embedding function fθ, parameterized by θ, encodes raw inputs into feature vectors. Similar to Siamese Neural Network, these embeddings are trained to be useful for telling whether two inputs are of the same class (verification task).
  • A base learner model gϕ, parameterized by weights ϕ, completes the actual learning task.

If we stop here, it looks just like Relation Network. MetaNet, in addition, explicitly models the fast weights of both functions and then aggregates them back into the model (See Fig. 8).

Therefore we need two additional functions to output fast weights for $f$ and $g$ respectively.

  • $F_w$: an LSTM parameterized by $w$ for learning fast weights $\theta^+$ of the embedding function $f$. It takes as input gradients of $f$’s embedding loss for the verification task.
  • $G_v$: a neural network parameterized by $v$ learning fast weights $\phi^+$ for the base learner $g$ from its loss gradients. In MetaNet, the learner’s loss gradients are viewed as the meta information of the task.

Ok, now let’s see how meta networks are trained. The training data contains multiple pairs of datasets: a support set $S=\{(\mathbf{x}_i, y_i)\}_{i=1}^K$ and a test set $U=\{(\mathbf{x}'_j, y'_j)\}_{j=1}^L$. Recall that we have four networks and four sets of model parameters to learn, $(\theta, \phi, w, v)$.


Fig. 9. The MetaNet architecture.

Training Process

  1. Sample a random pair of inputs at each time step $t$ from the support set $S$, $(\mathbf{x}_i, y_i)$ and $(\mathbf{x}_j, y_j)$. Let $\mathbf{x}_{(t,1)}=\mathbf{x}_i$ and $\mathbf{x}_{(t,2)}=\mathbf{x}_j$.
    for $t = 1, \dots, K$:
    • a. Compute a loss for representation learning; i.e., cross entropy for the verification task:
      $$\mathcal{L}^\text{emb}_t = -\mathbf{1}_{y_i=y_j}\log P_t - (1 - \mathbf{1}_{y_i=y_j})\log(1 - P_t)\text{, where } P_t = \sigma(\mathbf{W}\vert f_\theta(\mathbf{x}_{(t,1)}) - f_\theta(\mathbf{x}_{(t,2)})\vert)$$
  2. Compute the task-level fast weights: $\theta^+ = F_w(\nabla_\theta\mathcal{L}^\text{emb}_1, \dots, \nabla_\theta\mathcal{L}^\text{emb}_K)$
  3. Next go through the examples in the support set $S$ and compute the example-level fast weights. Meanwhile, update the memory with the learned representations.
    for $i = 1, \dots, K$:
    • a. The base learner outputs a probability distribution, $P(\hat{y}_i \vert \mathbf{x}_i) = g_\phi(\mathbf{x}_i)$, and the loss can be cross-entropy or MSE: $\mathcal{L}^\text{task}_i = -y_i \log g_\phi(\mathbf{x}_i) - (1 - y_i)\log(1 - g_\phi(\mathbf{x}_i))$
    • b. Extract the meta information (loss gradients) of the task and compute the example-level fast weights: $\phi^+_i = G_v(\nabla_\phi\mathcal{L}^\text{task}_i)$
      • Then store $\phi^+_i$ into the $i$-th location of the “value” memory $\mathbf{M}$.
    • c. Encode the support sample into a task-specific input representation using both slow and fast weights: $r_i = f_{\theta, \theta^+}(\mathbf{x}_i)$
      • Then store $r_i$ into the $i$-th location of the “key” memory $\mathbf{R}$.
  4. Finally, it is time to construct the training loss using the test set $U=\{(\mathbf{x}'_j, y'_j)\}_{j=1}^L$.
    Start with $\mathcal{L}_\text{train}=0$:
    for $j = 1, \dots, L$:
    • a. Encode the test sample into a task-specific input representation: $r'_j = f_{\theta, \theta^+}(\mathbf{x}'_j)$
    • b. The fast weights are computed by attending to the representations of support set samples in memory $\mathbf{R}$. The attention function is of your choice. Here MetaNet uses cosine similarity:
      $$a_j = \text{cosine}(\mathbf{R}, r'_j) = \Big[\frac{r_1 \cdot r'_j}{\|r_1\| \cdot \|r'_j\|}, \dots, \frac{r_K \cdot r'_j}{\|r_K\| \cdot \|r'_j\|}\Big]\text{,}\quad \phi^+_j = \text{softmax}(a_j)^\top \mathbf{M}$$
    • c. Update the training loss: $\mathcal{L}_\text{train} \leftarrow \mathcal{L}_\text{train} + \mathcal{L}^\text{task}(g_{\phi, \phi^+_j}(\mathbf{x}'_j), y'_j)$
  5. Update all the parameters $(\theta, \phi, w, v)$ using $\mathcal{L}_\text{train}$.
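
As a small illustration of step 4b above, here is a PyTorch sketch of how the example-level fast weights could be read out of the memories $\mathbf{R}$ and $\mathbf{M}$ via cosine-similarity attention; the shapes are made up and the fast weights are kept as flattened vectors.

```python
import torch
import torch.nn.functional as F

def example_level_fast_weights(r_query, R_keys, M_values):
    """Step 4b: attend over stored support representations (key memory R) with cosine
    similarity, then mix the stored fast weights (value memory M) by the softmax weights."""
    a = F.cosine_similarity(r_query.unsqueeze(0), R_keys, dim=-1)   # [K]
    return F.softmax(a, dim=0) @ M_values                           # phi_j^+

R_keys = torch.randn(5, 64)          # r_i for K = 5 support samples
M_values = torch.randn(5, 200)       # phi_i^+ stored per support sample (flattened)
r_query = torch.randn(64)            # representation of one test sample
phi_plus = example_level_fast_weights(r_query, R_keys, M_values)
```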

Optimization-Based

Deep learning models learn through backpropagation of gradients. However, gradient-based optimization is neither designed to cope with a small number of training samples, nor to converge within a small number of optimization steps. Is there a way to adjust the optimization algorithm so that the model can be good at learning with a few examples? This is what optimization-based meta-learning algorithms aim for.

LSTM Meta-Learner

The optimization algorithm can be explicitly modeled. Ravi & Larochelle (2017) did so and named it “meta-learner”, while the original model for handling the task is called “learner”. The goal of the meta-learner is to efficiently update the learner’s parameters using a small support set so that the learner can adapt to the new task quickly.

Let’s denote the learner model as $M_\theta$ parameterized by $\theta$, the meta-learner as $R_\Theta$ with parameters $\Theta$, and the loss function as $\mathcal{L}$.

Why LSTM?

The meta-learner is modeled as an LSTM, because:

  1. There is similarity between the gradient-based update in backpropagation and the cell-state update in LSTM.
  2. Knowing a history of gradients benefits the gradient update; think about how momentum works.

The update for the learner’s parameters at time step $t$ with a learning rate $\alpha_t$ is:

$$\theta_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}}\mathcal{L}_t$$

It has the same form as the cell state update in an LSTM if we set the forget gate $f_t=1$, the input gate $i_t = \alpha_t$, the cell state $c_t = \theta_t$, and the new cell state $\tilde{c}_t = -\nabla_{\theta_{t-1}}\mathcal{L}_t$:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t = \theta_{t-1} - \alpha_t\nabla_{\theta_{t-1}}\mathcal{L}_t$$

While fixing $f_t=1$ and $i_t=\alpha_t$ might not be optimal, both of them can be learnable and adaptable to different datasets.

$$\begin{aligned}
f_t &= \sigma(\mathbf{W}_f \cdot [\nabla_{\theta_{t-1}}\mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, f_{t-1}] + \mathbf{b}_f) & \text{; how much to forget the old value of parameters.} \\
i_t &= \sigma(\mathbf{W}_i \cdot [\nabla_{\theta_{t-1}}\mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, i_{t-1}] + \mathbf{b}_i) & \text{; corresponding to the learning rate at time step } t. \\
\tilde{\theta}_t &= -\nabla_{\theta_{t-1}}\mathcal{L}_t \\
\theta_t &= f_t \odot \theta_{t-1} + i_t \odot \tilde{\theta}_t
\end{aligned}$$
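
The gate equations above can be sketched coordinate-wise as below. This is only an illustration of the learned update rule with made-up shapes; the actual implementation in the paper involves extra details such as preprocessing the gradient and loss inputs and batching all coordinates through a shared LSTM.

```python
import torch

def meta_learner_update(theta_prev, grad, loss, f_prev, i_prev, W_f, b_f, W_i, b_i):
    """One coordinate-wise step of the learned update rule; every tensor except the
    meta-learner parameters (W_f, b_f, W_i, b_i) has the learner-parameter shape."""
    # Gate inputs: [gradient, loss, previous parameter, previous gate value].
    inp_f = torch.stack([grad, loss.expand_as(grad), theta_prev, f_prev], dim=-1)
    inp_i = torch.stack([grad, loss.expand_as(grad), theta_prev, i_prev], dim=-1)
    f_t = torch.sigmoid(inp_f @ W_f + b_f)   # forget gate: how much of theta to keep
    i_t = torch.sigmoid(inp_i @ W_i + b_i)   # input gate: an adaptive learning rate
    theta_t = f_t * theta_prev + i_t * (-grad)
    return theta_t, f_t, i_t

# Toy usage: 10 learner parameters; the meta-learner weights are shared across coordinates.
theta, grad, loss = torch.randn(10), torch.randn(10), torch.tensor(0.3)
W_f, b_f, W_i, b_i = torch.randn(4), torch.zeros(1), torch.randn(4), torch.zeros(1)
theta, f_t, i_t = meta_learner_update(theta, grad, loss, torch.ones(10),
                                      torch.full((10,), 0.1), W_f, b_f, W_i, b_i)
```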

Model Setup


Fig. 10. How the learner $M_\theta$ and the meta-learner $R_\Theta$ are trained. (Image source: original paper with more annotations)

The training process mimics what happens during test, since it was shown to be beneficial in Matching Networks. During each training epoch, we first sample a dataset $D = (D_\text{train}, D_\text{test}) \in \hat{\mathcal{D}}_\text{meta-train}$ and then sample mini-batches out of $D_\text{train}$ to update $\theta$ for $T$ rounds. The final state of the learner parameter $\theta_T$ is used to train the meta-learner on the test data $D_\text{test}$.

Two implementation details to pay extra attention to:

  1. How to compress the parameter space in the LSTM meta-learner? Since the meta-learner is modeling parameters of another neural network, it would have hundreds of thousands of variables to learn. Following the idea of sharing parameters across coordinates, the meta-learner applies the same update rule to every learner parameter coordinate, with its LSTM weights shared across all coordinates.
  2. To simplify the training process, the meta-learner assumes that the loss $\mathcal{L}_t$ and the gradient $\nabla_{\theta_{t-1}}\mathcal{L}_t$ are independent.


MAML

MAML, short for Model-Agnostic Meta-Learning (Finn et al., 2017), is a fairly general optimization algorithm, compatible with any model that learns through gradient descent.

Let’s say our model is $f_\theta$ with parameters $\theta$. Given a task $\tau_i$ and its associated dataset $(D^{(i)}_\text{train}, D^{(i)}_\text{test})$, we can update the model parameters by one or more gradient descent steps (the following example only contains one step):

$$\theta'_i = \theta - \alpha\nabla_\theta \mathcal{L}^{(0)}_{\tau_i}(f_\theta)$$

where $\mathcal{L}^{(0)}$ is the loss computed using the mini data batch with id (0).


Fig. 11. Diagram of MAML. (Image source: original paper)

Well, the above formula only optimizes for one task. To achieve good generalization across a variety of tasks, we would like to find the optimal $\theta^*$ so that the task-specific fine-tuning is more efficient. Now, we sample a new data batch with id (1) for updating the meta-objective. The loss, denoted as $\mathcal{L}^{(1)}$, depends on mini batch (1). The superscripts in $\mathcal{L}^{(0)}$ and $\mathcal{L}^{(1)}$ only indicate different data batches, and they refer to the same loss objective for the same task.

$$\begin{aligned}
\theta^* &= \arg\min_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}^{(1)}_{\tau_i}(f_{\theta'_i}) = \arg\min_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}^{(1)}_{\tau_i}\big(f_{\theta - \alpha\nabla_\theta \mathcal{L}^{(0)}_{\tau_i}(f_\theta)}\big) \\
\theta &\leftarrow \theta - \beta \nabla_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}^{(1)}_{\tau_i}\big(f_{\theta - \alpha\nabla_\theta \mathcal{L}^{(0)}_{\tau_i}(f_\theta)}\big) & \text{; updating rule}
\end{aligned}$$


Fig. 12. The general form of MAML algorithm. (Image source: original paper)
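
Below is a compact PyTorch sketch of one MAML meta-update with a single inner step, using `torch.autograd.grad` with `create_graph=True` so that second derivatives flow into the outer update. The toy learner, `loss_fn`, task sampler, and hyperparameter values are made up for illustration.

```python
import torch

def maml_meta_step(theta, tasks, loss_fn, alpha=0.01, beta=0.001):
    """One MAML meta-update with a single inner step per task.
    `theta` is a list of tensors with requires_grad=True; `tasks` yields (D_train, D_test)."""
    meta_loss = 0.0
    for D_train, D_test in tasks:
        # Inner step: theta_i' = theta - alpha * grad L^(0)(theta) on D_train.
        grads = torch.autograd.grad(loss_fn(theta, D_train), theta, create_graph=True)
        theta_prime = [p - alpha * g for p, g in zip(theta, grads)]
        # Outer objective: L^(1)(theta_i') evaluated on a fresh batch D_test.
        meta_loss = meta_loss + loss_fn(theta_prime, D_test)
    # Meta-update on the initial parameters; second derivatives flow through `grads`.
    meta_grads = torch.autograd.grad(meta_loss, theta)
    return [(p.detach() - beta * g).requires_grad_() for p, g in zip(theta, meta_grads)]

# Toy usage: linear regression tasks with an explicit parameter list.
def loss_fn(params, batch):
    X, y = batch
    w, b = params
    return ((X @ w + b - y) ** 2).mean()

theta = [torch.randn(3, requires_grad=True), torch.zeros(1, requires_grad=True)]
tasks = [((torch.randn(20, 3), torch.randn(20)), (torch.randn(20, 3), torch.randn(20)))
         for _ in range(4)]
theta = maml_meta_step(theta, tasks, loss_fn)
```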

First-Order MAML

The meta-optimization step above relies on second derivatives. To make the computation less expensive, a modified version of MAML omits second derivatives, resulting in a simplified and cheaper implementation, known as First-Order MAML (FOMAML).

Let’s consider the case of performing $k$ inner gradient steps, $k \geq 1$. Starting with the initial model parameter $\theta_\text{meta}$:

$$\begin{aligned}
\theta_0 &= \theta_\text{meta} \\
\theta_1 &= \theta_0 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_0) \\
\theta_2 &= \theta_1 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_1) \\
&\dots \\
\theta_k &= \theta_{k-1} - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_{k-1})
\end{aligned}$$

Then in the outer loop, we sample a new data batch for updating the meta-objective.

$$\begin{aligned}
\theta_\text{meta} &\leftarrow \theta_\text{meta} - \beta g_\text{MAML} & \text{; update for meta-objective} \\
\text{where } g_\text{MAML} &= \nabla_\theta \mathcal{L}^{(1)}(\theta_k) \\
&= \nabla_{\theta_k}\mathcal{L}^{(1)}(\theta_k) \cdot (\nabla_{\theta_{k-1}}\theta_k) \dots (\nabla_{\theta_0}\theta_1) \cdot (\nabla_\theta \theta_0) & \text{; following the chain rule} \\
&= \nabla_{\theta_k}\mathcal{L}^{(1)}(\theta_k) \cdot \Big(\prod_{i=1}^k \nabla_{\theta_{i-1}}\theta_i\Big) \cdot I \\
&= \nabla_{\theta_k}\mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k \nabla_{\theta_{i-1}}\big(\theta_{i-1} - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1})\big) \\
&= \nabla_{\theta_k}\mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k \big(I - \alpha\nabla_{\theta_{i-1}}(\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1}))\big)
\end{aligned}$$

The MAML gradient is:

$$g_\text{MAML} = \nabla_{\theta_k}\mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k \big(I - \alpha\nabla_{\theta_{i-1}}(\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1}))\big)$$

First-Order MAML ignores the second-derivative product term above. The gradient is simplified as follows, equivalent to the derivative of the last inner gradient update result.

$$g_\text{FOMAML} = \nabla_{\theta_k}\mathcal{L}^{(1)}(\theta_k)$$

Reptile

Reptile (Nichol, Achiam & Schulman, 2018) is a remarkably simple meta-learning optimization algorithm. It is similar to MAML in many ways, given that both rely on meta-optimization through gradient descent and both are model-agnostic.

Reptile works by repeatedly:

  • 1) sampling a task,
  • 2) training on it by multiple gradient descent steps,
  • 3) and then moving the model weights towards the new parameters.

See the algorithm below: $\text{SGD}(\mathcal{L}_{\tau_i}, \theta, k)$ performs a stochastic gradient update for $k$ steps on the loss $\mathcal{L}_{\tau_i}$, starting with the initial parameter $\theta$, and returns the final parameter vector. The batch version samples multiple tasks instead of one within each iteration. The Reptile gradient is defined as $(\theta - W)/\alpha$, where $\alpha$ is the stepsize used by the SGD operation.


Fig. 13. The batched version of Reptile algorithm. (Image source: original paper)

At a glance, the algorithm looks a lot like ordinary SGD. However, because the task-specific optimization can take more than one step, $\text{SGD}(\mathbb{E}_\tau[\mathcal{L}_\tau], \theta, k)$ eventually diverges from $\mathbb{E}_\tau[\text{SGD}(\mathcal{L}_\tau, \theta, k)]$ when $k > 1$.
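
A serial (non-batched) Reptile meta-iteration is only a few lines. Below is a hedged PyTorch sketch where the task sampler, model, and learning rates are all placeholders.

```python
import copy
import torch

def reptile_step(model, task_loss_fn, task_batches, inner_lr=0.02, meta_lr=0.1, k=5):
    """One serial Reptile meta-iteration: run k SGD steps on a sampled task, then move
    the meta-parameters a fraction of the way toward the adapted weights."""
    fast_model = copy.deepcopy(model)
    opt = torch.optim.SGD(fast_model.parameters(), lr=inner_lr)
    for step in range(k):
        # W <- SGD(L_tau, theta, k): ordinary gradient steps on this task only.
        loss = task_loss_fn(fast_model, task_batches[step % len(task_batches)])
        opt.zero_grad(); loss.backward(); opt.step()
    with torch.no_grad():
        # theta <- theta + meta_lr * (W - theta): interpolate toward the adapted weights.
        for p, p_fast in zip(model.parameters(), fast_model.parameters()):
            p.add_(meta_lr * (p_fast - p))

# Toy usage: fit a random linear task.
model = torch.nn.Linear(3, 1)
def task_loss_fn(m, batch):
    X, y = batch
    return ((m(X).squeeze(-1) - y) ** 2).mean()
reptile_step(model, task_loss_fn, [(torch.randn(16, 3), torch.randn(16))])
```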

The Optimization Assumption

Assume that a task $\tau \sim p(\tau)$ has a manifold of optimal network configurations, $\mathcal{W}_\tau^*$. The model $f_\theta$ achieves the best performance for task $\tau$ when $\theta$ lies on the surface of $\mathcal{W}_\tau^*$. To find a solution that is good across tasks, we would like to find a parameter close to all the optimal manifolds of all tasks:

$$\theta^* = \arg\min_\theta \mathbb{E}_{\tau\sim p(\tau)}\Big[\frac{1}{2}\text{dist}(\theta, \mathcal{W}_\tau^*)^2\Big]$$


Fig. 14. The Reptile algorithm updates the parameter alternatively to be closer to the optimal manifolds of different tasks. (Image source: original paper)

Let’s use the L2 distance as $\text{dist}(\cdot)$; the distance between a point $\theta$ and a set $\mathcal{W}_\tau^*$ equals the distance between $\theta$ and the point $W_\tau^*(\theta)$ on the manifold that is closest to $\theta$:

$$\text{dist}(\theta, \mathcal{W}_\tau^*) = \text{dist}(\theta, W_\tau^*(\theta))\text{, where } W_\tau^*(\theta) = \arg\min_{W \in \mathcal{W}_\tau^*} \text{dist}(\theta, W)$$

The gradient of the squared Euclidean distance is:

$$\nabla_\theta\Big[\frac{1}{2}\text{dist}(\theta, \mathcal{W}_{\tau_i}^*)^2\Big] = \nabla_\theta\Big[\frac{1}{2}\text{dist}(\theta, W_{\tau_i}^*(\theta))^2\Big] = \nabla_\theta\Big[\frac{1}{2}\|\theta - W_{\tau_i}^*(\theta)\|_2^2\Big] = \theta - W_{\tau_i}^*(\theta) \text{; see notes.}$$

Notes: According to the Reptile paper, “the gradient of the squared euclidean distance between a point Θ and a set S is the vector 2(Θ − p), where p is the closest point in S to Θ”. Technically the closest point in S is also a function of Θ, but I’m not sure why the gradient does not need to worry about the derivative of p. (Please feel free to leave me a comment or send me an email about this if you have ideas.)

Thus the update rule for one stochastic gradient step is:

$$\theta = \theta - \alpha\nabla_\theta\Big[\frac{1}{2}\text{dist}(\theta, \mathcal{W}_{\tau_i}^*)^2\Big] = \theta - \alpha(\theta - W_{\tau_i}^*(\theta)) = (1-\alpha)\theta + \alpha W_{\tau_i}^*(\theta)$$

The closest point on the optimal task manifold, $W_{\tau_i}^*(\theta)$, cannot be computed exactly, but Reptile approximates it using $\text{SGD}(\mathcal{L}_\tau, \theta, k)$.

Reptile vs FOMAML

To demonstrate the deeper connection between Reptile and MAML, let’s expand the update formula with an example performing two gradient steps, $k=2$, in $\text{SGD}(\cdot)$. Same as defined above, $\mathcal{L}^{(0)}$ and $\mathcal{L}^{(1)}$ are losses using different mini-batches of data. For ease of reading, we adopt two simplified notations: $g^{(i)}_j = \nabla_\theta\mathcal{L}^{(i)}(\theta_j)$ and $H^{(i)}_j = \nabla^2_\theta\mathcal{L}^{(i)}(\theta_j)$.

$$\begin{aligned}
\theta_0 &= \theta_\text{meta} \\
\theta_1 &= \theta_0 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_0) = \theta_0 - \alpha g^{(0)}_0 \\
\theta_2 &= \theta_1 - \alpha\nabla_\theta\mathcal{L}^{(1)}(\theta_1) = \theta_0 - \alpha g^{(0)}_0 - \alpha g^{(1)}_1
\end{aligned}$$

According to the earlier section, the gradient of FOMAML is the last inner gradient update result. Therefore, when $k=1$:

$$\begin{aligned}
g_\text{FOMAML} &= \nabla_{\theta_1}\mathcal{L}^{(1)}(\theta_1) = g^{(1)}_1 \\
g_\text{MAML} &= \nabla_{\theta_1}\mathcal{L}^{(1)}(\theta_1) \cdot (I - \alpha\nabla^2_\theta\mathcal{L}^{(0)}(\theta_0)) = g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1
\end{aligned}$$

The Reptile gradient is defined as:

$$g_\text{Reptile} = (\theta_0 - \theta_2)/\alpha = g^{(0)}_0 + g^{(1)}_1$$

Up to now we have:


Fig. 15. Reptile versus FOMAML in one loop of meta-optimization. (Image source: slides on Reptile by Yoonho Lee.)

$$\begin{aligned}
g_\text{FOMAML} &= g^{(1)}_1 \\
g_\text{MAML} &= g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1 \\
g_\text{Reptile} &= g^{(0)}_0 + g^{(1)}_1
\end{aligned}$$

Next let’s try to further expand $g^{(1)}_1$ using a Taylor expansion. Recall that the Taylor expansion of a function $f(x)$ that is differentiable at a number $a$ is:

$$f(x) = f(a) + \frac{f'(a)}{1!}(x-a) + \frac{f''(a)}{2!}(x-a)^2 + \dots = \sum_{i=0}^\infty \frac{f^{(i)}(a)}{i!}(x-a)^i$$

We can consider $\nabla_\theta\mathcal{L}^{(1)}(\cdot)$ as a function and $\theta_0$ as the value point. The Taylor expansion of $g^{(1)}_1$ at the value point $\theta_0$ is:

$$\begin{aligned}
g^{(1)}_1 &= \nabla_\theta\mathcal{L}^{(1)}(\theta_1) \\
&= \nabla_\theta\mathcal{L}^{(1)}(\theta_0) + \nabla^2_\theta\mathcal{L}^{(1)}(\theta_0)(\theta_1 - \theta_0) + \frac{1}{2}\nabla^3_\theta\mathcal{L}^{(1)}(\theta_0)(\theta_1 - \theta_0)^2 + \dots \\
&= g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + \frac{\alpha^2}{2}\nabla^3_\theta\mathcal{L}^{(1)}(\theta_0)(g^{(0)}_0)^2 - \dots & \text{; because } \theta_1 - \theta_0 = -\alpha g^{(0)}_0 \\
&= g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + O(\alpha^2)
\end{aligned}$$

Plug the expanded form of $g^{(1)}_1$ into the MAML gradients with one step of inner gradient update:

$$\begin{aligned}
g_\text{FOMAML} &= g^{(1)}_1 = g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + O(\alpha^2) \\
g_\text{MAML} &= g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1 \\
&= g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + O(\alpha^2) - \alpha H^{(0)}_0 \big(g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + O(\alpha^2)\big) \\
&= g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 - \alpha H^{(0)}_0 g^{(1)}_0 + \alpha^2 H^{(0)}_0 H^{(1)}_0 g^{(0)}_0 + O(\alpha^2) \\
&= g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 - \alpha H^{(0)}_0 g^{(1)}_0 + O(\alpha^2)
\end{aligned}$$

The Reptile gradient becomes:

$$g_\text{Reptile} = g^{(0)}_0 + g^{(1)}_1 = g^{(0)}_0 + g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + O(\alpha^2)$$

So far we have the formula of three types of gradients:

$$\begin{aligned}
g_\text{FOMAML} &= g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + O(\alpha^2) \\
g_\text{MAML} &= g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 - \alpha H^{(0)}_0 g^{(1)}_0 + O(\alpha^2) \\
g_\text{Reptile} &= g^{(0)}_0 + g^{(1)}_0 - \alpha H^{(1)}_0 g^{(0)}_0 + O(\alpha^2)
\end{aligned}$$

During training, we often average over multiple data batches. In our example, the mini batches (0) and (1) are interchangeable since both are drawn at random. The expectation $\mathbb{E}_{\tau, 0, 1}$ is averaged over the two data batches, ids (0) and (1), for task $\tau$.

Let,

  • $A = \mathbb{E}_{\tau, 0, 1}[g^{(0)}_0] = \mathbb{E}_{\tau, 0, 1}[g^{(1)}_0]$; it is the average gradient of the task loss. We expect to improve the model parameters to achieve better task performance by following the direction pointed to by $A$.
  • $B = \mathbb{E}_{\tau, 0, 1}[H^{(1)}_0 g^{(0)}_0] = \frac{1}{2}\mathbb{E}_{\tau, 0, 1}[H^{(1)}_0 g^{(0)}_0 + H^{(0)}_0 g^{(1)}_0] = \frac{1}{2}\mathbb{E}_{\tau, 0, 1}[\nabla_\theta(g^{(0)}_0 \cdot g^{(1)}_0)]$; it is the direction (gradient) that increases the inner product of gradients of two different mini batches for the same task. We expect to improve the model parameters to achieve better generalization over different data by following the direction pointed to by $B$.

To conclude, both MAML and Reptile aim to optimize for the same goal, better task performance (guided by $A$) and better generalization (guided by $B$), when the gradient update is approximated by the first three leading terms.

$$\begin{aligned}
\mathbb{E}_{\tau, 0, 1}[g_\text{FOMAML}] &= A - \alpha B + O(\alpha^2) \\
\mathbb{E}_{\tau, 0, 1}[g_\text{MAML}] &= A - 2\alpha B + O(\alpha^2) \\
\mathbb{E}_{\tau, 0, 1}[g_\text{Reptile}] &= 2A - \alpha B + O(\alpha^2)
\end{aligned}$$

It is not clear to me whether the ignored term $O(\alpha^2)$ might have a big impact on parameter learning. But given that FOMAML is able to obtain performance similar to the full version of MAML, it might be safe to say that higher-order derivatives are not critical during the gradient descent update.


Cited as:

@article{weng2018metalearning,
  title   = "Meta-Learning: Learning to Learn Fast",
  author  = "Weng, Lilian",
  journal = "lilianweng.github.io/lil-log",
  year    = "2018",
  url     = "http://lilianweng.github.io/lil-log/2018/11/29/meta-learning.html"
}

If you notice mistakes and errors in this post, don’t hesitate to leave a comment or contact me at [lilian dot wengweng at gmail dot com] and I would be very happy to correct them asap.

See you in the next post!

Reference

[1] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. “Human-level concept learning through probabilistic program induction.” Science 350.6266 (2015): 1332-1338.

[2] Oriol Vinyals’ talk on “Model vs Optimization Meta Learning”

[3] Gregory Koch, Richard Zemel, and Ruslan Salakhutdinov. “Siamese neural networks for one-shot image recognition.” ICML Deep Learning Workshop. 2015.

[4] Oriol Vinyals, et al. “Matching networks for one shot learning.” NIPS. 2016.

[5] Flood Sung, et al. “Learning to compare: Relation network for few-shot learning.” CVPR. 2018.

[6] Jake Snell, Kevin Swersky, and Richard Zemel. “Prototypical Networks for Few-shot Learning.” NIPS. 2017.

[7] Adam Santoro, et al. “Meta-learning with memory-augmented neural networks.” ICML. 2016.

[8] Alex Graves, Greg Wayne, and Ivo Danihelka. “Neural turing machines.” arXiv preprint arXiv:1410.5401 (2014).

[9] Tsendsuren Munkhdalai and Hong Yu. “Meta Networks.” ICML. 2017.

[10] Sachin Ravi and Hugo Larochelle. “Optimization as a Model for Few-Shot Learning.” ICLR. 2017.

[11] Chelsea Finn’s BAIR blog on “Learning to Learn”.

[12] Chelsea Finn, Pieter Abbeel, and Sergey Levine. “Model-agnostic meta-learning for fast adaptation of deep networks.” ICML 2017.

[13] Alex Nichol, Joshua Achiam, John Schulman. “On First-Order Meta-Learning Algorithms.” arXiv preprint arXiv:1803.02999 (2018).

[14] Slides on Reptile by Yoonho Lee.