Transformer Models Explained
Transformers are all you need?
Parts of the Transformer Model
A transformer model is made of transformer blocks stacked in sequence, and it may also follow an encoder-decoder scheme. A transformer block contains an attention mechanism and a linear feedforward block. Transformers usually combine layer normalization layers and residual connections around each attention block and linear block, along with dropout layers in some cases; whether the normalization is applied before or after each block depends on whether the architecture follows a pre-normalization or post-normalization convention. The encoder section takes the inputs to the transformer and creates a representation of them through multiple transformer blocks. The decoder then takes the previous predictions as input, performs attention on them, and then performs attention between the encoder outputs and the output of the attention block over the previous predictions. This continues through each transformer block until the final layer, which varies depending on the task.
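To make the block structure concrete, here is a rough sketch of a single encoder-style transformer block in Keras, showing where the layer normalization sits under the pre-normalization and post-normalization conventions. The layer sizes and dropout rate are placeholder values I chose for illustration, not settings from any particular paper.

import tensorflow as tf

class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1, pre_norm=True):
        super().__init__()
        self.pre_norm = pre_norm
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads)
        self.feedforward = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation="relu"),
            tf.keras.layers.Dense(d_model),
        ])
        self.norm1 = tf.keras.layers.LayerNormalization()
        self.norm2 = tf.keras.layers.LayerNormalization()
        self.drop = tf.keras.layers.Dropout(dropout)

    def call(self, x, training=False):
        if self.pre_norm:
            # Pre-normalization: normalize the input to each sub-layer,
            # then wrap the residual connection around it.
            h = self.norm1(x)
            x = x + self.drop(self.attention(h, h), training=training)
            h = self.norm2(x)
            x = x + self.drop(self.feedforward(h), training=training)
        else:
            # Post-normalization: apply the sub-layer, add the residual,
            # and normalize the sum (the original Transformer convention).
            x = self.norm1(x + self.drop(self.attention(x, x), training=training))
            x = self.norm2(x + self.drop(self.feedforward(x), training=training))
        return x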
Commonly, the linear layers within a transformer are implemented as one-by-one convolutions instead of dense layers. On a rank-two input, the two perform an identical operation. The data input for a transformer, however, has the form batch size by sequence length by embedding dimension, so it has a rank of three. In this case, a one-by-one convolution learns a single set of weights that mixes the data channels in the embedding dimension and applies it identically at every sequence position. A dense layer applied to the last axis of a rank-three input does the same thing, whereas a dense layer applied to the flattened sequence would also learn interactions across positions and therefore use far more parameters than a one-by-one convolution. It is possible to use either variety of linear layer in a transformer to achieve adequate results. In my experience working with transformer models, I have run into issues in TensorFlow with out-of-memory errors during training while using one-by-one convolutions instead of dense layers.
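As a quick check of that equivalence, here is a small TensorFlow sketch comparing a one-by-one convolution and a position-wise dense layer on a rank-three input; the shapes are arbitrary values I picked for illustration.

import tensorflow as tf

batch_size, seq_len, d_model, d_ff = 4, 16, 64, 256
x = tf.random.normal((batch_size, seq_len, d_model))

# One-by-one convolution: mixes the embedding channels independently
# at every sequence position.
conv = tf.keras.layers.Conv1D(filters=d_ff, kernel_size=1)

# Dense layer on a rank-three input: Keras applies it to the last axis,
# so it is also a position-wise channel mix.
dense = tf.keras.layers.Dense(d_ff)

print(conv(x).shape, dense(x).shape)              # both (4, 16, 256)
print(conv.count_params(), dense.count_params())  # both d_model * d_ff + d_ff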
Attention Mechanism
The attention mechanism in transformer models can vary, but most commonly it is a scaled dot-product attention mechanism operating on a set of queries, keys, and values. I will explain this mechanism with an example where the transformer model is used for a natural language processing task. Attention acts as a soft dictionary: the queries and keys are multiplied to measure the similarity between their vectors and, with it, word relevance. The keys can be thought of as a word’s definition, while the queries capture how the word’s meaning can change based on its context. This product is then scaled by the square root of the key dimension for numerical stability before being put through a softmax operation, which turns each row of query-key products into a probability distribution. The resulting weights can be thought of as finding, for each word in the input, how much of its occurrence can be attributed to each of the other words in the sentence. Finally, these probabilities are multiplied by the values, which can be thought of as the main idea, or what the sentence is trying to say.
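Written out, scaled dot-product attention is only a few lines. Below is a minimal NumPy sketch of the mechanism described above; the array shapes are illustrative, and a real implementation would operate on batches and use learned projections for the queries, keys, and values.

import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(queries, keys, values):
    # Similarity between every query and every key: (seq_len, seq_len).
    scores = queries @ keys.T
    # Scale by the square root of the key dimension for numerical stability.
    scores = scores / np.sqrt(keys.shape[-1])
    # Turn each row of scores into a probability distribution over positions.
    weights = softmax(scores, axis=-1)
    # Weighted sum of the values: (seq_len, d_v).
    return weights @ values

seq_len, d_k, d_v = 5, 64, 64
q = np.random.randn(seq_len, d_k)
k = np.random.randn(seq_len, d_k)
v = np.random.randn(seq_len, d_v)
print(scaled_dot_product_attention(q, k, v).shape)  # (5, 64)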
In many cases, it is essential to mask out words that appear later in the sentence, because they cannot cause a previous word. Masking means forcing the attention weight that each word assigns to any later word to zero, typically by setting the corresponding scores to a large negative value before the softmax. Then, as a transformer model repeatedly predicts the next word to generate an output, an attention block builds a hierarchical representation of the sentence by performing attention at each step with respect to the previously established context. The queries, keys, and values are each produced by a separate linear layer or block of linear layers.
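For causal masking specifically, a small sketch that reuses the softmax helper from the sketch above looks like this; again the approach shown (adding a very negative value to the hidden positions) is just one common way to implement it.

def causal_mask(seq_len):
    # Lower-triangular matrix: position i may only attend to positions <= i.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def masked_attention(queries, keys, values):
    scores = queries @ keys.T / np.sqrt(keys.shape[-1])
    # Set the scores for future positions to a very negative value
    # so the softmax drives their attention weights to (effectively) zero.
    scores = np.where(causal_mask(len(queries)), scores, -1e9)
    return softmax(scores, axis=-1) @ values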
In practice, each attention block applies this mechanism with a multi-headed approach. The idea is to separate the embedding dimension into sections containing different information and to run a separate attention head on each section. At the end of the block, the outputs of the individual heads are concatenated. I like to think of this as each head holding a different possible interpretation of each word in a sentence, since the same word can have multiple meanings; an individual head then checks whether that particular definition of the word makes sense.
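A rough sketch of that splitting and concatenation, building on the scaled dot-product function from the sketch above; note that most real implementations project the full input into each head with learned per-head linear layers rather than literally slicing the embedding, so this is only an illustration of the idea.

def multi_head_attention(x, num_heads):
    # x: (seq_len, d_model); in a real model the queries, keys, and values
    # would come from separate learned linear projections rather than x itself.
    seq_len, d_model = x.shape
    head_dim = d_model // num_heads
    heads = []
    for h in range(num_heads):
        # Each head sees its own slice of the embedding dimension.
        slice_h = x[:, h * head_dim:(h + 1) * head_dim]
        heads.append(scaled_dot_product_attention(slice_h, slice_h, slice_h))
    # Concatenate the per-head outputs back to (seq_len, d_model).
    return np.concatenate(heads, axis=-1)

print(multi_head_attention(np.random.randn(6, 64), num_heads=8).shape)  # (6, 64)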
Difficulties in Training Transformers
Transformer models are notoriously difficult to train. According to the research papers that I have read on this topic, there are two reasons for this difficulty. The first is high variance in the magnitude of the gradients during the first steps of gradient descent. The second, which affects post-normalization transformers, is an amplification effect in the variance of layer outputs that prevents a large transformer from propagating information back to the earlier layers of the network during training, even with residual connections. Pre-normalization transformers normalize the inputs to each layer, which limits each layer's dependency on the previous ones and thereby avoids the amplification problem. This makes pre-normalization transformers simpler to train, at the cost of potentially important information from previous layers being lost during normalization. Post-normalization transformers can have more expressive layers and achieve higher performance, but they are more difficult to train.
Rectified Adam and LAMB Optimizers
The problem of high-variance gradients in the early stages of training can be addressed by using a warm-up phase in a learning rate scheduler, the Rectified Adam (RAdam) optimizer, or the LAMB optimizer. A warm-up phase addresses the problem by gradually increasing the learning rate at the start of training, before the learning rate decays, so that the model's weights reach a more suitable neighborhood before being fine-tuned. The Rectified Adam optimizer approaches the problem by rectifying the adaptive learning rate, which is computed from the exponential moving averages inside the Adam optimizer and has high variance when only a few gradients have been observed. For the amplification problem in post-normalization transformers trained with RAdam, a proposed solution is ADMIN initialization, which initializes the weights that control the variance of layer outputs so that the variance stays bounded in the early stage of training. The LAMB optimizer addresses the amplification effect by using an adaptive learning rate for each layer of the model, and the problem of gradient variance by using a large batch size.
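As an example of the warm-up idea, here is a sketch of the learning-rate schedule from the original Transformer paper, where the rate rises linearly for a number of warm-up steps and then decays with the inverse square root of the step count; the d_model and warmup_steps values here are only illustrative defaults.

def transformer_lr(step, d_model=512, warmup_steps=4000):
    # Linear warm-up for the first warmup_steps, then inverse-square-root decay.
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The learning rate peaks around step == warmup_steps and decays afterwards.
for s in (1, 1000, 4000, 20000):
    print(s, transformer_lr(s))

If you want to try RAdam or LAMB in TensorFlow, implementations exist in the TensorFlow Addons package (tfa.optimizers.RectifiedAdam and tfa.optimizers.LAMB), though whether that package works for you will depend on your TensorFlow version.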
Tips and Tricks
When faced with difficulty training a transformer model, consider increasing the number of warm-up steps in the learning rate scheduler, using either RAdam or LAMB as the optimizer, and using a pre-normalization convention within the transformer model architecture. Larger batch sizes also help to stabilize training. Hyperparameters should be chosen carefully for the particular transformer variant being implemented, as they can significantly impact model performance. Averaging the weights of the checkpoints from the final stages of training can improve model performance considerably, as it combats the problem of more recent training examples having a disproportionate effect on the model's output.
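Checkpoint averaging is simple to do by hand. Here is a rough Keras sketch, assuming a list of checkpoint files that all share the model's architecture; the file names at the end are made up purely for illustration.

import numpy as np

def average_checkpoints(model, checkpoint_paths):
    # Collect the weight arrays from each checkpoint.
    all_weights = []
    for path in checkpoint_paths:
        model.load_weights(path)
        all_weights.append([w.copy() for w in model.get_weights()])
    # Average each weight tensor across the checkpoints and load the result.
    averaged = [np.mean(stack, axis=0) for stack in zip(*all_weights)]
    model.set_weights(averaged)
    return model

# Hypothetical checkpoint files from the last few epochs of training:
# model = average_checkpoints(model, ["ckpt_epoch_18.h5", "ckpt_epoch_19.h5", "ckpt_epoch_20.h5"])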
Alternative Modalities
Transformer models are not just useful for natural language processing tasks. The Vision Transformer model, for example, models image data. It does this by dividing an image into patches, and it can take features for each patch from a pre-trained computer vision model such as the Inception model. In effect, it turns the input image into a sequence of symbolic representations of the content of the image's patches. A more general variant of the transformer is the Perceiver. The Perceiver combines the input with Fourier features, applies cross-attention against a learnable latent vector, and uses iterative attention in a manner more similar to a recurrent neural network, which lets it model various data modalities.
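To show what dividing an image into patches looks like in practice, here is a small TensorFlow sketch that cuts a batch of images into non-overlapping patches and flattens each one into a vector, ready to be projected into a token embedding; the image and patch sizes are arbitrary choices for the example.

import tensorflow as tf

def image_to_patches(images, patch_size):
    # images: (batch, height, width, channels)
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    batch = tf.shape(patches)[0]
    num_patches = patches.shape[1] * patches.shape[2]
    # Each flattened patch becomes one element of the sequence fed to the transformer.
    return tf.reshape(patches, (batch, num_patches, -1))

images = tf.random.normal((2, 224, 224, 3))
print(image_to_patches(images, patch_size=16).shape)  # (2, 196, 768)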
Thoughts on Transformers
Transformer models are a dominant model architecture. I believe that they will continue to scale well, learn from fewer examples than other models, and store vast amounts of data. They will benefit from big data while other models will be more suited to smaller data sets or when prediction speed matters more than accuracy. In general, I expect transformers to become more prevalent in various domains.