Ever wondered how Google Translate really works and wanted to build something like it? In this blog post, we build a Transformer-based translation system inspired by the landmark “Attention Is All You Need” paper. Along the way we unpack self-attention, multi-head attention, positional encodings, and the encoder–decoder stack through a hands-on PyTorch implementation trained on the massive Samanantar English–Indic corpus, show how to tune epochs, batch size, and learning rate, deploy a Streamlit demo that mirrors Google Translate, and trace the Transformer’s rise from its RNN and LSTM predecessors.
Introduction
In 2017, the deep learning community was transformed by a landmark paper titled “Attention Is All You Need” by Vaswani et al. This paper introduced the Transformer architecture, which has since become the foundation of state-of-the-art models in machine translation, natural language processing (NLP), and beyond. Transformers dramatically changed how sequential data is processed by doing away with recurrence entirely and relying solely on attention mechanisms. In this comprehensive post, I will:
Explain the foundational Transformer paper in depth, including the intuition and mathematics behind self-attention.
Walk through an actual PyTorch implementation of a Transformer (from my Montekkundan/pytorch-transformer repo) to see how the concepts translate to code.
Introduce the Samanantar dataset, a massive multilingual parallel corpus (English ↔ 11 Indic languages) used for training our translation model.
Break down how training and configuration works for a multilingual Transformer, clarifying terms like epochs, batch size, learning rate, and how to adjust them.
Guide you through testing and inference: how to load trained weights, generate translations, and use provided utilities for quick tests.
Explore a simple Streamlit web app for interactive translation, demonstrating how to build a UI akin to Google Translate for model output visualization.
Set the historical context by comparing pre-Transformer translation methods (like RNNs, LSTMs, and early Google Translate systems) with Transformers, discussing trade-offs in complexity vs. quality and whether Transformers are worth it for similar tasks.
Whether you’re a beginner just learning about sequence models or an intermediate practitioner looking to deepen your understanding, this post will take you from the basics of attention to advanced implementation details. Let’s dive in!
Before reading this whole post, note that I also wrote a smaller, much easier introduction to how Transformers work. While writing this blog I felt a simpler version was needed, so you can read that here.
And this is what we will be building:
“Attention Is All You Need”: Transformer Foundations
The paper “Attention Is All You Need” (Vaswani et al., 2017) dispensed with the recurrent nature of previous sequence models and introduced a novel architecture built entirely on attention mechanisms. Before Transformers, most translation models were based on recurrent neural networks (RNNs) or LSTMs processing sequences token by token. The Transformer’s key insight was to process whole sequences in parallel by leveraging a mechanism called self-attention. This allowed models to consider the entire context of a sentence at once, rather than word-by-word sequential processing.
1.1 Encoder-Decoder Structure with Attention
At a high level, a Transformer is composed of an Encoder and a Decoder, both of which are stacks of layers (the original paper used 6 layers each). The encoder takes an input sentence (e.g., in French) and produces a set of continuous representations. The decoder then generates an output sentence (e.g., in English) from those representations, one token at a time, with the help of attention. Unlike earlier seq2seq models that used RNNs for the encoder and decoder, the Transformer’s encoder and decoder use only attention and feed-forward layers, allowing for significantly more parallelization during training.
Below is a conceptual diagram of the Transformer’s data flow:

Figure: High-level Transformer architecture. The input sequence is first passed through embedding layers (with positional encodings added) to produce continuous vector representations. Multiple encoder layers (stacked) process these representations using self-attention and feed-forward networks, producing encoded features. The decoder (multiple layers stacked) uses self-attention on the target (already generated) tokens and cross-attention over the encoder’s output. Finally, a projection layer produces output token probabilities. The dashed arrow indicates the encoder-to-decoder attention connections.
Each Encoder layer contains two primary sublayers:
Self-attention: The layer allows each position in the input sequence to attend to (i.e., compute relations with) every other position. This helps the model weigh which other words in the input are most relevant to understanding a particular word.
Position-wise Feed-Forward Network: A simple two-layer fully connected network applied to each position’s representation.
Each Decoder layer contains three sublayers:
Self-attention on the decoder’s output so far (with a mask to prevent attending to future tokens that haven’t been generated yet).
Encoder-Decoder attention (cross-attention): This is attention where the queries come from the decoder’s previous layer, and the keys/values come from the encoder’s output. This mechanism allows the decoder to focus on appropriate places in the input sequence (encoder output) for each output token it’s generating.
A Feed-Forward Network, same as in the encoder.
Both encoder and decoder layers use residual connections (adding the sublayer’s input to its output) and layer normalization for stable training.
1.2 Self-Attention: Intuition and Math
Self-attention is the core operation that allows Transformers to model dependencies regardless of distance in the sequence. For each position in a sequence, self-attention computes a weighted combination of values from other positions, where the weights (attention scores) are derived from pairwise similarity of query and key vectors.
Mathematically, for a set of queries $Q$, keys $K$, and values $V$ (which are all transformations of the input sequence):

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This is the scaled dot-product attention formula introduced in the paper. The $QK^T$ term computes dot products between each query and each key (yielding a score of how much each sequence element should attend to another). Dividing by $\sqrt{d_k}$ (where $d_k$ is the dimensionality of keys/queries) stabilizes gradients for large vectors. A softmax turns these scores into a probability distribution over attention “weights.” Finally, these weights are used to sum up the corresponding $V$ (value) vectors. Each output position is essentially a weighted sum of the input values, where the weights highlight relevant words for that position.
In code, a simplified version of scaled dot-product attention might look like:
# Pseudocode for scaled dot-product attention
scores = (Q @ K.T) / math.sqrt(d_k)               # Dot products between queries and keys
if mask is not None:
    scores = scores.masked_fill(mask == 0, -inf)  # Mask out positions (optional, for padding or future tokens)
weights = softmax(scores, dim=-1)                 # Normalize to get attention probabilities
output = weights @ V                              # Weight the values by the attention
This corresponds to what we see in the actual implementation: the code multiplies query @ key.transpose(-2, -1) and scales by sqrt(d_k), applies a mask by filling certain scores with a large negative value (so they become essentially zero after softmax), then softmaxes and multiplies by value. The result is both the weighted sum (to be passed to the next layer) and the attention weights (often useful for introspection or visualization).
Multi-Head Attention extends this idea by running multiple self-attention operations in parallel. Instead of a single attention over very large vectors, the Transformer uses $h$ different “heads” (the paper’s default was $h=8$ heads). For each head, the input vectors are projected into a lower-dimensional subspace (with dimension $d_k = d_{\text{model}}/h$), and attention is computed in that subspace. Each head potentially focuses on different types of relationships (e.g., one head might focus on direct word alignments, another on syntactic roles). After the $h$ independent attentions, their outputs are concatenated and linearly transformed to produce the final output for that layer.
Mathematically:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O, \qquad \text{head}_i = \text{Attention}(QW^Q_i,\; KW^K_i,\; VW^V_i)$$

where each head uses separate projection matrices $W^Q_i, W^K_i, W^V_i$ for the $i$-th head, and $W^O$ is the output projection. This allows the model to attend to information from different representation subspaces.
In the code implementation, you can see how multi-head attention is realized. First, the model has learned projection matrices w_q, w_k, w_v (and w_o for the output) to project input features into these subspaces. In the forward method of MultiHeadAttention, after computing query = self.w_q(q) (and similarly key and value), the code reshapes these into (batch, n_heads, seq_len, d_k) by splitting the last dimension (d_model) into (n_heads, d_k). This effectively creates the $h$ parallel sets of Q, K, V. Then attention is computed as described above for each head, and finally the heads are concatenated back (the code uses .transpose and .contiguous().view to merge the head and d_k dimensions back into d_model), followed by the output linear layer w_o. The implementation closely mirrors the math:
# Inside MultiHeadAttention.forward
query = self.w_q(q)  # project to d_model
query = query.view(batch, seq_len, n_heads, d_k).transpose(1, 2)  # shape to (batch, n_heads, seq_len, d_k)
# ... do same for key, value ...
x, attn_weights = MultiHeadAttention.attention(query, key, value, mask, dropout)
x = x.transpose(1, 2).contiguous().view(batch, seq_len, d_model)  # concat heads
output = self.w_o(x)  # final linear projection
(Note: MultiHeadAttention.attention is a static method carrying out the scaled dot-product logic discussed above.)
One important detail: Masking. As noted in the pseudocode, we apply masks in attention. In the Transformer, masks serve two purposes:
Padding mask: In the encoder (or decoder), to ignore padded positions in sequences of different lengths (so that they don’t influence the attention scores).
Future mask (a.k.a. look-ahead or causal mask): In the decoder’s self-attention, to prevent the model from peeking at later tokens when predicting the next token. This ensures the decoder can be used for autoregressive generation. The mask is effectively lower triangular: ones on and below the diagonal (each position may attend to itself and past tokens) and zeros above it (future positions are blocked).
The code uses such masks. For example, a causal_mask function creates the triangular mask, and when computing decoder attention it is combined with a padding mask. The model code’s decoder mask generation looks like:

decoder_mask = (decoder_input != pad_id).unsqueeze(0).int() & causal_mask(decoder_input.size(0))
which ensures decoder self-attention can only attend to earlier tokens and non-pad tokens.
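For reference, here is a minimal sketch of what such a causal_mask helper typically looks like (the repo’s exact implementation may differ slightly):

import torch

def causal_mask(size: int) -> torch.Tensor:
    # Ones above the diagonal mark "future" positions; inverting gives a mask that is
    # True (allowed) on and below the diagonal and False (blocked) above it.
    future = torch.triu(torch.ones(1, size, size), diagonal=1).int()
    return future == 0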
1.3 Positional Encoding: Adding Order Information
One challenge with Transformers is that, unlike RNNs, they don’t inherently encode sequence order (since they look at all tokens simultaneously). The solution proposed is Positional Encoding – adding a vector to each token’s embedding that conveys its position in the sequence. These encodings are designed to allow the model to infer relative positions.
The original paper defined positional encodings using sine and cosine functions of different frequencies:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

for each position pos (0-indexed) and each dimension index i of the positional encoding vector. In other words, each even-indexed dimension uses a sine and each odd-indexed dimension a cosine, with a different wavelength for each pair. The beauty of this formulation is that it allows the model to easily learn to attend by relative positions – for example, the dot product of positional encodings carries information about how far apart two positions are (due to properties of sine and cosine).
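To make the relative-position property concrete, here is a small derivation sketch (not from the post’s code, just the standard trigonometric argument): with $\omega_i = 1/10000^{2i/d_{\text{model}}}$, the encoding at position $pos + k$ is a fixed linear combination (a rotation of each sine/cosine pair) of the encoding at position $pos$, independent of $pos$:

$$\begin{aligned}
\sin\big(\omega_i (pos + k)\big) &= \sin(\omega_i\, pos)\cos(\omega_i k) + \cos(\omega_i\, pos)\sin(\omega_i k) \\
\cos\big(\omega_i (pos + k)\big) &= \cos(\omega_i\, pos)\cos(\omega_i k) - \sin(\omega_i\, pos)\sin(\omega_i k)
\end{aligned}$$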
In our implementation’s PositionalEncoding class, these values are computed once in the constructor:

position = torch.arange(0, seq_len).unsqueeze(1).float()  # shape (seq_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)  # shape (1, seq_len, d_model)
self.register_buffer('pe', pe)  # store for use without gradients
This exactly implements the formula above. The forward method of PositionalEncoding then simply adds the positional encoding to the input embeddings and applies dropout. By using register_buffer, the positional encoding pe is saved inside the model and moved to the appropriate device, but is not treated as a learnable parameter (it’s fixed).
Before the positional encodings are added, the Transformer scales the input embeddings by $\sqrt{d_{\text{model}}}$ (in code, self.embedding(x) * math.sqrt(self.d_model)), a minor initialization detail that keeps the embedding signal on a comparable scale to the positional encoding when the two are summed.
1.4 Putting It Together: Transformer Encoder & Decoder
To summarize, the Transformer Encoder (stack of N layers) takes an input sequence and produces an output sequence of the same length (each position is now a context-aware vector). Pseudo-steps for the encoder:
Embedding + Positional Encoding: Convert tokens to vectors, add positional info.
For each of N layers:
Self-attention: each position attends to all positions (including itself) in the sequence.
Add & Normalize: Add the original input of this layer (residual connection) and apply layer normalization.
Feed-Forward: a 2-layer MLP applied to each position.
Add & Normalize again.
Output the normalized vectors.
The Transformer Decoder works similarly but with an extra cross-attention:
Embedding + Positional Encoding of the target sequence (so far).
For each of N layers:
Self-attention on the decoder’s own sequence (with causal mask to prevent looking ahead).
Add & Norm.
Encoder-Decoder attention: queries from decoder attend to keys/values from the encoder’s final output. This is where the decoder “looks” at the source sentence to gather relevant information for each output token.
Add & Norm.
Feed-Forward.
Add & Norm.
A final linear layer (projection) converts the decoder’s outputs to scores over the vocabulary (followed by softmax to get probabilities).
Crucially, residual connections (introduced by He et al. for deep networks) are used around every sublayer. In the code, there’s a ResidualConnection module which wraps a sublayer: it takes an input x, applies layer normalization (self.norm(x)), then the sublayer function, then adds the original x (after dropout). This particular implementation uses a Pre-LayerNorm design (normalize before the sublayer), which is a slight variation from the original Transformer (which did post-norm). Pre-norm is known to stabilize training for deep Transformers.
An encoder layer is implemented in EncoderBlock:
def forward(self, x, src_mask):
    x = self.residual_connections[0](x,
        lambda x: self.self_attention_block(x, x, x, src_mask))
    x = self.residual_connections[1](x, self.feed_forward_block)
    return x
This shows two residual sublayers: first self-attention (with query, key, and value all set to x for self-attention on the source, plus a mask), then feed-forward. The decoder layer (DecoderBlock) is similar but with three sublayers (self-attn, cross-attn, feed-forward).
Finally, the model outputs probabilities via a projection layer: the paper’s original design uses a linear layer to project the $d_{\text{model}}$ outputs to vocabulary size, followed by softmax. In code, ProjectionLayer is a simple nn.Linear(d_model, vocab_size). The output of the decoder stack goes through this to produce logits for each vocabulary token.
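For completeness, the whole module is roughly this (a sketch based on the description above; the actual class may, for example, return log-probabilities instead of raw logits):

import torch.nn as nn

class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
        return self.proj(x)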
To reinforce understanding, let’s trace the flow with code from the Transformer class:
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        ...  # (initialization storing submodules)
    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)  # pass through encoder stack
    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)  # decoder stack
    def project(self, x):
        return self.projection_layer(x)  # linear layer to vocab
The encode method applies the source embedding and positional encoding, then runs the encoder. The decode method embeds the target, adds positional encodings, and runs the decoder (note: the decoder needs the encoder_output and src_mask to perform cross-attention). Finally, project maps the decoder’s final hidden states to vocabulary logits.
1.5 Why “Attention Is All You Need” Was a Breakthrough
To appreciate the impact, consider that pre-Transformer models either processed words sequentially (RNNs/LSTMs) or had limited context (CNN-based models could only look at a fixed window of tokens at a time). The Transformer’s attention mechanism allowed every output to directly attend to every input (and, in self-attention, every input to every other input) with no decrease in efficiency for long distances – any token can influence any other token in essentially 1 step. This gave three big advantages:
Parallelism: An RNN over a sequence of length T requires T sequential steps (it cannot fully parallelize because each step waits for the previous one). In a Transformer, the length dimension can be processed in parallel (attention is just matrix multiplication across the whole sequence). This leads to faster training on parallel hardware.
Long-range dependencies: Transformers handle long sequences with ease because even tokens far apart can directly interact via attention. RNNs had difficulty retaining information over long sequences due to vanishing gradients (LSTMs mitigated this but still struggled beyond certain lengths).
Quality improvements: When properly trained, Transformers achieved higher translation quality (as measured by BLEU scores) than the recurrent models used before. The richness of multi-head attention allowed the model to capture complex language phenomena (like syntax and coreference) better.
The original paper reported state-of-the-art results in machine translation, and importantly, the Transformer big model (with more layers and heads) improved further, showing the scaling potential of the architecture. Soon after, Transformers became the basis of BERT, GPT, and numerous other advanced models, making this paper one of the most influential in AI in the past decade.
In the next section, we’ll bridge theory with practice by examining a PyTorch implementation of a Transformer model – effectively seeing “Attention Is All You Need” in code.
2. Inside the PyTorch Transformer Implementation (model.py Walkthrough)
To solidify our understanding, let’s go through the code from the repository Montekkundan/pytorch-transformer, specifically the model.py file, which implements a Transformer for machine translation (very much following the Vaswani paper’s design). This implementation is a standard encoder–decoder Transformer with multi-head attention, and it’s configured for multilingual translation (we’ll get to the dataset soon). We’ll explain each part of the model in detail and show how the concepts map to code.
Overview of model.py contents:
Input embedding layer (InputEmbeddings)
Positional encoding (PositionalEncoding)
Layer normalization (LayerNormalization) – a custom implementation
Feed-forward network block (FeedForwardBlock)
Multi-head attention block (MultiHeadAttention)
Residual connection wrapper (ResidualConnection)
Encoder block (EncoderBlock) and Encoder stack (Encoder)
Decoder block (DecoderBlock) and Decoder stack (Decoder)
Output projection layer (ProjectionLayer)
Assembly of the Transformer model (Transformer class and build_transformer function)
We’ll go step by step through these.
2.1 Token Embeddings and Positional Encoding
At the very start of the file, we have:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)
This InputEmbeddings layer is simple – it wraps a PyTorch nn.Embedding (which maps token indices to learned vectors of size d_model). The forward pass multiplies the embedding by $\sqrt{d_{\text{model}}}$ as discussed, a scaling factor recommended by the original Transformer paper (the positional encoding is added right after this, and the scaling helps keep the two signals balanced at initialization).
Next, the PositionalEncoding class:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape (1, seq_len, d_model)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + (self.pe[:, :x.size(1), :]).requires_grad_(False)
        return self.dropout(x)
We already discussed the math behind this. Key points:
It precomputes a matrix pe of shape (seq_len, d_model) where each row is the positional encoding for that position.
It uses register_buffer to store pe so it moves with the model’s device (CPU/GPU) but isn’t a learnable parameter.
In forward, it slices pe to the length of the sequence (since x.size(1) is the sequence length for that batch) and adds it to the input x. Calling requires_grad_(False) on that slice ensures this constant doesn’t receive gradients (not strictly necessary given register_buffer, but a safe practice).
A dropout is applied to the sum (the original paper used dropout after each sublayer and on the embeddings as well).
So effectively, when encoding an input sequence:
embeddings = src_embed(src_tokens)  # [batch, seq_len, d_model]
embeddings = src_pos(embeddings)    # add positional encodings + dropout
Now embeddings is ready to feed into the encoder.
2.2 Multi-Head Attention Module
Now, the heart of the model: the MultiHeadAttention class:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        # WQ, WK, WV linear layers
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # WO output linear layer
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return (attention_scores @ value), attention_scores
    def forward(self, q, k, v, mask):
        # 1. Linear projections
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
        # 2. Reshape to multiple heads
        query = query.view(query.shape[0], query.shape[1], self.n_heads, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.n_heads, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.n_heads, self.d_k).transpose(1, 2)
        # 3. Scaled Dot-Product Attention on each head
        x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)
        # 4. Concatenate heads
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_heads * self.d_k)
        # 5. Final linear projection
        return self.w_o(x)
(I have split the code into steps and added comments for clarity).
I have actually described most of this above. But notice a few things:
assert d_model % n_heads == 0: ensures the head dimension is an integer (a common pitfall with odd sizes; here a 512-dimensional model with 8 heads gives d_k = 64).
w_q, w_k, w_v: learned weight matrices that project input features to Q, K, V. Each has shape (d_model, d_model), taking a vector of length d_model and outputting a vector of length d_model. These are effectively W^Q, W^K, W^V.
w_o: the output linear layer that combines the heads’ outputs, also (d_model, d_model).
In forward, q, k, v can be the same tensor (for self-attention we pass the same tensor three times) or different (for cross-attention, q comes from the decoder while k = v come from the encoder).
After projection, query has shape (batch, seq_len, d_model). After view and transpose, it becomes (batch, n_heads, seq_len, d_k); the transpose(1, 2) swaps the sequence-length and head dimensions so that the head dimension comes right after batch.
The attention weights are stored in self.attention_scores, which allows later retrieval for debugging or visualization (not strictly needed for the forward pass).
Finally, the heads are concatenated: x.transpose(1, 2) brings the tensor back to (batch, seq_len, n_heads, d_k), and .view(batch, seq_len, n_heads * d_k) merges the last two dimensions back into d_model. Then w_o is applied to mix the heads.
This module is used for three kinds of attention in our Transformer:
Encoder self-attention (called with q = k = v = encoder hidden states, mask = padding mask).
Decoder self-attention (called with q = k = v = decoder hidden states so far, mask = padding mask combined with the future mask).
Encoder-decoder cross-attention (called with q = decoder hidden states, k = v = encoder output, mask = source padding mask so the decoder doesn’t attend to padded positions in the encoder output).
The code simply instantiates two MultiHeadAttention objects for the decoder (one for self-attention, one for cross-attention) and one for the encoder (self-attention).
2.3 Feed-Forward Network
The feed-forward network in each layer is a simple two-layer MLP applied independently to each sequence position. In code, FeedForwardBlock is:
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # W1
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)  # W2
    def forward(self, x):
        # (Batch, seq_len, d_model) -> (Batch, seq_len, d_ff) -> (Batch, seq_len, d_model)
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))
This corresponds to the “position-wise feed-forward network” in the Transformer paper. Typically $d_{\text{ff}}$ (here d_ff) is larger than $d_{\text{model}}$ (the paper used 2048 vs. 512). The network has a ReLU activation in between and a dropout. The code comments clarify the shapes: the block is applied at every batch and sequence position, which works because a linear layer applied to a (batch, seq_len, d_model) tensor broadcasts over the leading dimensions (PyTorch’s linear accepts an input of shape (N, *, in_features) and treats all but the last dimension as batch dimensions).
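A quick, illustrative check of that broadcasting behavior (shapes chosen to match the base configuration; this snippet is not from the repo):

import torch
import torch.nn as nn

linear = nn.Linear(512, 2048)   # d_model -> d_ff
x = torch.randn(32, 350, 512)   # (batch, seq_len, d_model)
print(linear(x).shape)          # torch.Size([32, 350, 2048]) - same weights applied at every position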
Combining this with attention gives the model both non-linear transformations (from the FFN) and context-mixing (from attention) in each layer.
2.4 Residual Connections and Layer Normalization
As mentioned, each sub-layer (attention or feed-forward) is wrapped with a residual add & normalization. The code provides ResidualConnection:
class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)
    def forward(self, x, sublayer):
        # Pre-LN: Apply norm before sublayer
        return x + self.dropout(sublayer(self.norm(x)))
This is pretty straightforward:
LayerNormalization is presumably a custom implementation (and indeed, earlier in the file, LayerNormalization is defined using the mean and std to normalize each example’s features). PyTorch now has nn.LayerNorm, but the custom code likely does something similar:

class LayerNormalization(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.eps = eps
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

This matches the standard definition $\text{LayerNorm}(x) = \frac{x - \mu}{\sigma + \epsilon} \cdot \gamma + \beta$, where $\gamma, \beta$ are a learned scale and bias (alpha and bias here).
In ResidualConnection.forward, they first do norm(x) (this is Pre-LN style), then call sublayer(norm(x)). The sublayer is expected to be a callable (either a lambda around the attention call, or the feed-forward module itself, as seen in EncoderBlock.forward).
Dropout is applied to the sublayer output, which is then added to the original x.
The decision to do Pre-norm (norm before sublayer) vs Post-norm (norm after adding) is subtle.
The original Transformer (Vaswani et al.) did post-norm (norm after adding residual). Many modern implementations use pre-norm because it tends to help with gradient flow in deep models (preventing extremely large values early in training). Our code above clearly does pre-norm (normalize then sublayer then add).
2.5 Encoder and Decoder Blocks
Using the pieces above:
EncoderBlock combines a MultiHeadAttention and a FeedForwardBlock, with two ResidualConnection wrappers:
class EncoderBlock(nn.Module):
    def __init__(self, features, self_attention_block, feed_forward_block, dropout):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
    def forward(self, x, src_mask):
        x = self.residual_connections[0](x,
            lambda inp: self.self_attention_block(inp, inp, inp, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
The first residual wrapper applies self-attention (notice it passes inp to self.self_attention_block as q, k, and v, using a lambda to also capture the src_mask). The second applies the feed-forward block.
Encoder (the stack) holds multiple EncoderBlocks:
class Encoder(nn.Module):
    def __init__(self, features, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
After passing through each layer in sequence, it applies a final LayerNorm (post-norm at the end of the entire encoder). This final normalization was mentioned in the paper as well.
DecoderBlock is similar but with a cross-attention:
class DecoderBlock(nn.Module):
    def __init__(self, features, self_attn_block, cross_attn_block, feed_forward_block, dropout):
        super().__init__()
        self.self_attention_block = self_attn_block
        self.cross_attention_block = cross_attn_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x,
            lambda inp: self.self_attention_block(inp, inp, inp, tgt_mask))
        x = self.residual_connections[1](x,
            lambda inp: self.cross_attention_block(inp, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x
Three residual sublayers: (1) self-attention on the decoder input x (with tgt_mask for causal masking), (2) cross-attention, where the queries are inp (decoder) and the keys/values are encoder_output (attending over the encoder’s output under the source mask), and (3) feed-forward. Each is wrapped with the same norm + dropout residual pattern.
Decoder (the stack) is analogous to Encoder:
class Decoder(nn.Module):
    def __init__(self, features, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)
It applies all decoder blocks in sequence and a final norm.
2.6 Building the Transformer Model
Finally, the pieces are tied together in the build_transformer function:
def build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len,
                      d_model=512, N=6, h=8, dropout=0.1, d_ff=2048) -> Transformer:
    # Embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    # Positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    # Build encoder stack
    encoder_blocks = []
    for _ in range(N):
        attn = MultiHeadAttention(d_model, h, dropout)
        ffn = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_blocks.append(EncoderBlock(d_model, attn, ffn, dropout))
    decoder_blocks = []
    for _ in range(N):
        self_attn = MultiHeadAttention(d_model, h, dropout)
        cross_attn = MultiHeadAttention(d_model, h, dropout)
        ffn = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_blocks.append(DecoderBlock(d_model, self_attn, cross_attn, ffn, dropout))
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
    # Initialize parameters with Xavier (Glorot) uniform
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer
This code is quite self-explanatory but worth noting:
It uses N=6 by default (6 encoder layers, 6 decoder layers), h=8 attention heads, d_model=512, and d_ff=2048 – exactly the hyperparameters of the original paper’s “base” model.
It creates separate MultiHeadAttention instances for each layer, so every layer gets its own parameters (as it should).
Encoder and decoder blocks are collected into nn.ModuleList objects and passed into Encoder/Decoder. Using ModuleList matters because it registers all sub-blocks as submodules, ensuring their parameters are tracked and learned.
All weight matrices are initialized with Xavier uniform (Glorot) initialization, a common choice for deep nets that sets the variance based on the in/out dimensions. The condition if p.dim() > 1 restricts this to weight matrices (skipping biases and LayerNorm parameters, which are 1-D). The Transformer paper did something similar; Xavier alone is a solid choice.
The returned Transformer object wraps everything together. Remember, this Transformer class (whose methods we saw earlier) defines encode, decode, and project. It doesn’t implement a single forward method that runs end-to-end in one call; instead, you call encode, then decode, then project as appropriate. That gives flexibility to handle things like teacher forcing during training or different generation strategies during inference (greedy decoding, beam search, etc.). For training, one could write a forward that simply calls the three in sequence, but here the author chose to expose encode/decode directly.
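As an illustration of how those three methods compose at inference time, here is a hedged sketch of greedy decoding (the special-token names follow the training setup described later; the repo’s own validation/inference code may differ):

import torch

def greedy_decode(model, src, src_mask, tokenizer, max_len, device):
    sos_id = tokenizer.token_to_id('<2en>')   # decoder start token: the target-language tag
    eos_id = tokenizer.token_to_id('[EOS]')

    encoder_output = model.encode(src, src_mask)              # encode the source once
    decoder_input = torch.tensor([[sos_id]], device=device)   # start the target sequence
    while decoder_input.size(1) < max_len:
        size = decoder_input.size(1)
        # causal mask: allow attention only to current and past positions
        tgt_mask = (torch.triu(torch.ones(1, size, size, device=device), diagonal=1) == 0)
        out = model.decode(encoder_output, src_mask, decoder_input, tgt_mask)
        logits = model.project(out[:, -1])                     # logits for the last position
        next_token = logits.argmax(dim=-1, keepdim=True)       # greedy choice
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if next_token.item() == eos_id:
            break
    return decoder_input.squeeze(0)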
2.7 Summary of Model Architecture
At this point, we’ve covered how each part of the Transformer is implemented. To summarize in plain language:
Embedding & PositionalEncoding: Turn tokens into vectors and add positional info.
Encoder: N layers each doing self-attention (with multiple heads) and a feed-forward network, with residual connections and normalization around each. The encoder transforms the input sequence into a sequence of context-aware features (same length, but each position’s vector now contains information from the whole sequence).
Decoder: N layers each doing self-attention on the output-so-far, then attention over encoder output (so the decoder can focus on relevant parts of the input), and a feed-forward, all with residual connections. The decoder produces the output sequence one position at a time (during training, we feed the ground-truth shifted targets to it; during inference, it generates step by step).
Projection to vocab: A linear layer that converts each decoder output vector to logits over the target-language vocabulary. Combined with a softmax, this gives a probability distribution for the next token.
By stacking multiple layers and using multi-head attention, the model can capture very complex relationships (some heads might focus on local structure like adjacent words, others on long-range dependencies like subject-verb agreement across a sentence, etc.). The feed-forward adds capacity to transform and mix the information after attention. Residual connections ensure gradients flow well and you can train deeper networks effectively.
This particular implementation aligns closely with the original Transformer (sometimes called the “transformer base” implementation). It’s a solid foundation for any sequence-to-sequence task. In our case, the task is multilingual translation with the Samanantar dataset, which we’ll explore next.
3. The Samanantar Dataset
Training a translation model requires a parallel corpus – a dataset of sentence pairs in the source and target languages. For training a multilingual English–Indic translator (where one system can handle multiple Indian languages), we need a large, diverse corpus covering English and many Indic languages. Enter Samanantar, the largest publicly available parallel corpus collection for Indic languages.
According to its creators, Samanantar contains a total of 49.7 million sentence pairs between English and 11 Indic languages. It was constructed by combining existing public corpora (~12.4M pairs) with an additional ~37.4M pairs mined from the web. This massive effort increased available parallel data by about 4x for these languages. The languages covered (11 Indic languages) include:
Assamese (as)
Bengali (bn)
Gujarati (gu)
Hindi (hi)
Kannada (kn)
Malayalam (ml)
Marathi (mr)
Odia (or)
Punjabi (pa)
Tamil (ta)
Telugu (te)
All are paired with English (en). Additionally, Samanantar provides mined parallel pairs between the Indian languages themselves (covering 55 possible pairs among 11 languages, totaling ~83.4M extra pairs), though our focus is on the English-centric portion.
Dataset structure: On Hugging Face datasets (where Samanantar is hosted as "ai4bharat/samanantar"), the dataset is organized by language. The code in our training script suggests that to load, for example, Hindi–English data, one would do:

load_dataset("ai4bharat/samanantar", "hi", split='train')
where "hi" is the configuration name (for Hindi–English). The repository’s config.py defines a list LANGUAGES = ["as","bn","gu","hi","kn","ml","mr","or","pa","ta","te"] and constructs lang_pairs = [f"{lang}-en" for lang in LANGUAGES]. This implies each language code corresponds to a pairing with English. Indeed, during training the code iterates through lang_src_codes (the list above) and always pairs with target 'en'. So effectively, it loads 11 separate subsets (as-en, bn-en, …, te-en) and then concatenates them into one big training set.
This kind of data organization (one combined dataset) is common for multilingual models. The model doesn’t explicitly know which language is which unless we tell it – and we do, via special tokens (prefixes indicating the language, which we’ll cover in the next section on training).
What makes Samanantar special? Its scale and diversity. It used multiple strategies to mine translations:
Crawling monolingual Indic and English text and finding alignments.
Using OCR on scanned documents to extract text.
Using multilingual representation models and approximate nearest neighbor search to match sentences across languages.
Validating with human evaluation to ensure quality of mined pairs.
The result is a high-quality, large dataset. The availability of Samanantar is a huge boost for building translation systems for Indian languages, many of which are low-resource (previously had little parallel data).
Size and considerations: 49.7 million pairs is a lot! Training on the full set can be time-consuming and resource-intensive. In fact, the config by default uses only 1% of the dataset ("dataset_fraction": 0.01 in get_config()) for experimentation. That’s roughly 0.5 million pairs, which is manageable on a single GPU for a proof of concept. For full training, one would set dataset_fraction: 1.0 to use all the data (assuming you have the compute for it).
We’ll see in the training code how the dataset is loaded and prepared, and how the languages are handled with special tokens.
4. Configuration and Training Process
Training a Transformer on a large multilingual dataset involves many moving parts: data loading and preprocessing, model instantiation, setting hyperparameters, and managing the training loop. In our project, these are handled by config.py and train.py. Let’s break down how training is configured and executed, while clarifying common terms.
4.1 Configuration (config.py)
The get_config() function defines all the key settings in a dictionary. Some important config fields and their meanings:
batch_size: Number of examples per training batch (default 32). This is how many sentence pairs we process before updating the model weights. A larger batch can stabilize training (via averaging gradients) but uses more memory.
num_epochs: Number of full passes through the training dataset (default 5). Each epoch means the model has seen every training example once. In practice, one often trains for many epochs until convergence (or uses early stopping).
lr: Initial learning rate for the optimizer (default 1e-4). This controls the step size in gradient descent. Note, however, that the code overrides this with a special scheduler; the config’s lr is not used directly as the final LR – it’s more of a base value.
seq_len: Maximum sequence length (default 350). Sentences are truncated or padded to this length. It should be long enough to cover most sentences in the data. (350 is quite generous; many sentences are shorter, but some web-mined sentences can be longer.)
d_model: Model dimensionality for embeddings and hidden layers (512). Must match what the model expects; 512 is standard, as discussed.
datasource: The Hugging Face dataset path ('ai4bharat/samanantar').
dataset_fraction: Fraction of the dataset to use (1.0 = full, 0.01 = 1%). Extremely useful for quick experiments; using 1% for initial runs is wise to ensure everything works.
lang_pairs: A list like ["as-en", "bn-en", ...] covering all languages. Not used directly in the training code, but informational.
lang_src_codes: List of source language codes (the 11 listed above).
lang_tgt_code: Target language code ("en").
src_lang_prefix and tgt_lang_token: Special tokens used to tag the source and target languages in the text. By default, SRC_LANG_PREFIX = "<2" and TGT_LANG_TOKEN = "<2en>" for an English target. For example, if the source is Hindi ("hi"), the token <2hi> is prepended to each source sentence and <2en> to each target.
model_folder & model_basename: Directory and filename prefix for saving model weights.
preload: If set to "latest", training will try to load the most recent checkpoint and resume.
tokenizer_file: Name of the file to save/load the tokenizer (shared across languages).
experiment_name: Directory for logging (e.g., TensorBoard logs).
warmup_epochs: Duration of the learning rate warmup, in epochs (1 by default).
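Putting these together, here is an abridged sketch of what get_config() plausibly returns (field names follow the descriptions above; the exact keys and values in the repo may differ, and the folder/file names marked below are assumptions):

def get_config():
    langs = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
    return {
        "batch_size": 32,
        "num_epochs": 5,
        "lr": 1e-4,
        "seq_len": 350,
        "d_model": 512,
        "datasource": "ai4bharat/samanantar",
        "dataset_fraction": 0.01,                  # 1% of the data for quick experiments
        "lang_src_codes": langs,
        "lang_tgt_code": "en",
        "lang_pairs": [f"{lang}-en" for lang in langs],
        "src_lang_prefix": "<2",                   # "<2" + "hi" + ">" -> "<2hi>"
        "tgt_lang_token": "<2en>",
        "model_folder": "weights_multi",           # assumed name
        "model_basename": "tmodel_multi_",         # assumed name
        "preload": "latest",
        "tokenizer_file": "tokenizer_multi.json",  # assumed name
        "experiment_name": "runs/tmodel_multi",    # assumed name
        "warmup_epochs": 1,
    }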
Additional helper functions like get_weights_file_path and latest_weights_file_path construct the paths for saving/loading model checkpoints based on the config (they incorporate the datasource name, the epoch number, etc.).
The special-token mechanism deserves a quick note. By adding a token like <2hi> at the start of a Hindi sentence, the model can identify that the following text is Hindi. Similarly, by adding <2en> at the start of the target, the model knows it should produce English. This is a typical trick in multilingual models to condition the output language. In our case, since English is always the target, <2en> acts as a constant start token for the decoder (like a task specifier). The source prefix <2xx> tells the model which language the source is in (so it can, e.g., internally differentiate Hindi from Tamil input).
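Conceptually, the tagging amounts to this toy illustration (the real code inserts the tags as token IDs inside the dataset class rather than editing the strings):

src_lang = "hi"
src_sentence = "यह एक उदाहरण वाक्य है।"          # "This is an example sentence." in Hindi
tgt_sentence = "This is an example sentence."

tagged_src = f"<2{src_lang}> {src_sentence}"     # encoder sees: <2hi> यह एक उदाहरण वाक्य है।
tagged_tgt = f"<2en> {tgt_sentence}"             # decoder is conditioned to produce English
print(tagged_src, "->", tagged_tgt)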
4.2 Data Loading and Preprocessing
The function get_ds(config) in train.py handles loading the datasets for all language pairs and preparing them for training:
It initializes an empty list all_ds_raw.
For each src_lang in config['lang_src_codes'] (i.e., each of the 11 languages), it sets tgt_lang = config['lang_tgt_code'] (which is "en"), forms lang_pair_str like "hi-en" just for logging, and attempts to load the dataset:

ds_pair = load_dataset(config['datasource'], src_lang, split=split_str)

where split_str is 'train' or a percentage slice like 'train[:1%]', depending on dataset_fraction. So if fraction=0.01 (1%), it actually loads just the first 1% of that sub-dataset’s train split.
It then maps over ds_pair to add a new field 'src_lang' with the language code. This ensures each example knows its source language (again, for use when adding the prefix later).
It appends the dataset to all_ds_raw.
If a dataset for a certain language can’t be loaded (maybe the dataset isn’t present, or the split name differs), the code catches the exception and logs a warning; for these 11 languages it should work if the data is available.
After the loop, all_ds_raw is a list of datasets (one per language). They call concatenate_datasets(all_ds_raw) to combine them into one big dataset with examples from all languages.
They optionally limit the dataset to max_test_samples if that key is in the config (for quick tests; e.g., a unit test sets it to 100).
Next, the tokenizer: they call get_or_build_tokenizer_multi(config, [combined_ds_raw]). This function either loads an existing tokenizer from tokenizer_file or trains a new one on the provided datasets (a minimal sketch of building such a tokenizer appears just before Section 4.3). The code for this function shows:
If the tokenizer file exists, load it.
If not, it builds a WordLevel tokenizer (from the Hugging Face tokenizers library) with a list of special_tokens including [UNK], [PAD], [SOS], [EOS] and all the language tokens (<2as>, ..., <2en>).
It then trains the tokenizer on an iterator over all sentences from the datasets. The get_all_sentences_multi function just yields every source and target sentence from each dataset provided.
This builds a single vocabulary that covers all languages (plus English). A shared vocabulary is common in multilingual models because it forces the model to allocate capacity to all languages and share embeddings (which can aid transfer learning, especially if languages share scripts or partial words – though here the scripts vary).
After training, it saves the tokenizer to a file for reuse.
In our scenario, the tokenizer might already be provided or can be built. Using a WordLevel tokenizer means each whitespace-separated word is a token – less flexible than subword tokenizers (like SentencePiece/BPE), especially for such diverse languages, but possibly sufficient as a simple design choice. (The vocabulary will be large, though; it is likely trimmed by a min_frequency=2 in the trainer, meaning very rare words are dropped and mapped to [UNK].)
Logging and splitting: the combined dataset size is logged and the data shuffled; a 90/10 train/val split is then made with random_split. There is no separately provided validation set; a slice of the training data serves that role.
The BilingualDataset class then wraps the raw dataset into a PyTorch Dataset that yields model-ready tensors. This is important: BilingualDataset(train_ds_raw, tokenizer, src_lang_prefix, tgt_lang_token, seq_len) constructs an object where each item is fully processed. Let’s inspect BilingualDataset.__getitem__ to see what it does with an example:

src_text = example['src']; tgt_text = example['tgt']; src_lang = example['src_lang']
src_lang_token_str = f"{src_lang_prefix}{src_lang}>"
# Convert text to token IDs
enc_input_tokens = tokenizer.encode(src_text).ids
dec_input_tokens = tokenizer.encode(tgt_text).ids
# Prepare encoder input: [SOS] + <2src> + src_tokens + [EOS] + [PAD...]
# Prepare decoder input: [<2tgt>] + tgt_tokens + [PAD...]
# Prepare label:         tgt_tokens + [EOS] + [PAD...]
# Create masks as well

Essentially, for each example it adds the special tokens:
For the encoder input sequence, it prepends [SOS] and the source language token (like <2hi> if src_lang is "hi") and appends [EOS], then pads to seq_len with [PAD].
For the decoder input sequence (what we feed into the decoder at time 0), it prepends the target language token <2en> and then the target sentence tokens (no EOS at the end of the decoder input), then pads to seq_len.
For the target labels (the expected output), it takes the target sentence tokens and appends [EOS], then pads. This label sequence is what the model should output (one position ahead of the decoder input).
Masks:
encoder_mask: a tensor of shape (1, 1, seq_len) that is 1 at non-PAD positions and 0 at PAD positions. The model uses this to mask out attention to pad tokens in the encoder.
decoder_mask: a tensor of shape (1, seq_len, seq_len) combining the causal (triangular) mask with a PAD mask for the decoder input, computed as (decoder_input != pad).unsqueeze(0).int() & causal_mask(seq_len). In effect, it is 1 for allowed attention positions and 0 where attention is not allowed.
(These are expanded to be broadcastable in attention; e.g., the encoder mask becomes [batch, 1, 1, seq_len] when used, matching [batch, n_heads, query_len, key_len].)
The BilingualDataset returns a dict with the fields "encoder_input", "decoder_input", "encoder_mask", "decoder_mask", "label", "src_text", "tgt_text", "src_lang". The last few (the raw texts and src_lang) are mostly for reference or debugging; training really uses the first five (inputs, masks, label).
This dataset class encapsulates all the preprocessing, so the training loop can simply fetch ready-to-use tensors.
Finally, DataLoader objects are created: a train_dataloader with batch_size 32 (shuffled) and a val_dataloader with batch_size 1 (not shuffled). Using batch size 1 for validation means examples are evaluated one at a time (which makes metric calculation and printing easier). That’s fine, given the validation set isn’t huge.
At this point, we have:
A model (built by build_transformer).
A tokenizer.
Data loaders for training and validation that will yield batches of prepared inputs.
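As referenced above, here is a minimal sketch of how a shared WordLevel tokenizer could be built with the Hugging Face tokenizers library (the sentence iterator and file name are placeholders; the repo’s get_or_build_tokenizer_multi may differ in its details):

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

langs = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]", "<2en>"] + [f"<2{l}>" for l in langs]

tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=special_tokens, min_frequency=2)

# Placeholder iterator: in the project this would yield every source and target sentence.
all_sentences = iter(["यह एक उदाहरण वाक्य है।", "This is an example sentence."])
tokenizer.train_from_iterator(all_sentences, trainer=trainer)
tokenizer.save("tokenizer_multi.json")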
4.3 Training Loop (train.py)
The train_model(config) function orchestrates the training. Let’s outline it:
Setup:
Fix random seeds for reproducibility (set_seed(42)).
Select the device (CUDA, MPS, or CPU).
Log the config and device.
Create the weights output folder (in case it doesn’t exist).
Load data and tokenizer:
train_dataloader, val_dataloader, tokenizer, _ = get_ds(config).
Get the vocab size from the tokenizer for the model.
Build model:
model = get_model(config, vocab_size), which calls build_transformer internally.
Move the model to the device (model.to(device)).
Set up a TensorBoard SummaryWriter (for logging training loss/metrics if needed).
Optimizer and Learning Rate Scheduler:
They use torch.optim.Adam with lr=1.0 and eps=1e-9. Why lr=1.0? Because a custom learning rate schedule controls the effective learning rate.
They compute warmup_steps = num_batches_per_epoch * warmup_epochs. If we have, say, 1000 batches per epoch and warmup_epochs=1, then warmup_steps=1000.
The scheduler is a LambdaLR with the lambda:

lr_lambda = lambda step: d_model**(-0.5) * min((step+1)**(-0.5), (step+1) * warmup_steps**(-1.5))

This is exactly the learning rate schedule described in the Transformer paper:
It increases linearly for the first warmup_steps steps, then decays proportionally to $1/\sqrt{step}$ afterwards.
Specifically, $(step+1) \cdot warmup\_steps^{-1.5}$ equals $\frac{step+1}{warmup\_steps^{1.5}}$, which is small at first and grows until it matches $(step+1)^{-0.5}$ at step = warmup_steps; after that, $(step+1)^{-0.5}$ is the smaller term, so the rate decays.
Multiplying by $d_{\text{model}}^{-0.5}$ is also from the paper (they found scaling by the model dimension helped).
They even set a tiny LR for step 0 (1e-8) to avoid a zero learning rate.
This schedule means the actual learning rate starts very low, increases to a peak at the end of warmup (around $d_{\text{model}}^{-0.5} \cdot warmup\_steps^{-0.5}$), then decreases.
optimizer and lr_scheduler are now ready.
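In code, the wiring described above plausibly looks like this (a sketch using the names from earlier; the repo’s exact implementation may differ):

import torch

d_model = config['d_model']
warmup_steps = len(train_dataloader) * config['warmup_epochs']

optimizer = torch.optim.Adam(model.parameters(), lr=1.0, eps=1e-9)

def noam_lambda(step: int) -> float:
    # Linear warmup for warmup_steps steps, then inverse-square-root decay.
    step = step + 1
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)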
Resume Training (if applicable):
If preload is set in the config (e.g., 'latest' by default), the code attempts to find the latest checkpoint and load it. latest_weights_file_path(config) finds the newest file matching the pattern in the weights folder.
If found, it loads state = torch.load(model_filename). This state is a dict containing:
'epoch': last epoch completed
'model_state_dict': the model weights
'optimizer_state_dict': optimizer state (momentum, etc.)
'scheduler_state_dict': scheduler state (current step and LR)
'global_step': total steps done
'config': the config dict
These are then loaded into the model, optimizer, and scheduler (if present).
initial_epoch is set to state['epoch'] + 1 so training starts at the next epoch.
This allows training to resume seamlessly. If no checkpoint is found, it logs that training starts from scratch.
Loss Function:
They use nn.CrossEntropyLoss with ignore_index=pad_id and label_smoothing=0.1.
The label smoothing of 0.1 is significant: the target distribution puts roughly 0.9 on the correct token and spreads the remaining 0.1 over the rest of the vocabulary (basically a slight mix with the uniform distribution). This was also used in the original paper; it helps prevent the model from becoming over-confident on the training data and can improve generalization.
ignore_index=pad_id ensures that pad tokens in the label don’t contribute to the loss (since sequences are padded to equal length).
They move the loss function to the device (not strictly necessary in PyTorch, since the loss computes on whatever device its inputs are on, but harmless).
Training Epochs:
They loop for
epoch in range(initial_epoch, config['num_epochs'])
.At each epoch:
Log epoch start, empty CUDA cache (to potentially free memory).
Set
model.train()
(enables dropout, etc.).Create a tqdm progress bar over
train_dataloader
.Iterate over each
batch
:Move data to device:
encoder_input, decoder_input, encoder_mask, decoder_mask, label
all to device.Forward pass:
1encoder_output = model.encode(encoder_input, encoder_mask) 2decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) 3proj_output = model.project(decoder_output) 4
python.
This yieldsproj_output
of shape (batch_size, seq_len, vocab_size). It’s basically the logits for each position.Compute loss:
1loss = loss_fn(proj_output.view(-1, vocab_size), label.view(-1)) 2
python. They flatten the predictions and labels so that it’s a 1D list of predictions vs targets (ignoring pad by the earlier ignore_index).
Log the current loss in the progress bar postfix, along with current LR (from scheduler).
Do
loss.backward()
to compute gradients.optimizer.step()
to update weights, thenoptimizer.zero_grad(set_to_none=True)
to clear grads.lr_scheduler.step()
to update the learning rate for the next step.Increment
global_step
.
After the epoch loop:
Switch to eval mode and run run_validation on the val_dataloader. The run_validation function presumably runs a similar pass over the validation data and computes some metrics, or at least prints a couple of example translations for inspection (the code collects predicted vs. expected outputs and catches the case where torchmetrics is not installed, suggesting it computes BLEU or WER when that library is available).
Save a checkpoint:

```python
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': lr_scheduler.state_dict(),
    'global_step': global_step,
    'config': config
}, model_filename_save)
```

where model_filename_save = get_weights_file_path(config, f"{epoch:02d}") (so the filename includes the epoch number). They log that the checkpoint has been saved.
Loop continues to next epoch.
After training (or after an interrupted run), the last completed epoch will have been saved. With num_epochs=5, you get files like tmodel_multi_00.pt, tmodel_multi_01.pt, ..., tmodel_multi_04.pt in a folder named ai4bharat_samanantar_weights_multi (constructed from the datasource and model_folder config values).
Understanding key terms in context:
Epoch: as we saw, one full pass over the data. They use num_epochs=5 by default, meaning the model sees that 0.01% subset of the dataset 5 times. When training on the full data, fewer epochs often suffice (or more, if needed).
Batch size (32): at each step, 32 sentence pairs are processed in parallel. With a sequence length of 350, the model processes 32 * 350 = 11,200 tokens per batch (roughly double that if you count both the source and target sequences, plus masks, etc.). That's moderate.
Learning rate & warmup: the initial small LR grows to a peak and then decays. The lr=1e-4 in the config is largely superseded by the scheduler logic: the peak effective LR is $d_{\text{model}}^{-0.5} \cdot \text{warmup\_steps}^{-0.5}$, and with d_model=512 and a warmup of about one epoch that comes out to roughly $512^{-0.5} \cdot 4000^{-0.5} \approx 7 \times 10^{-4}$ (i.e., around 0.0007, corresponding to roughly 4,000 warmup steps in this setup).
Adjusting based on system capacity:
batch_size: if you have a GPU with more memory, you might increase it to utilize that memory (larger batches give smoother gradient estimates). If you have less memory, lower it (even to 16 or 8) so training fits.
seq_len: could be lowered if memory is a big issue or if you know your sentences aren't that long. Lowering it to 100, for instance, would reduce memory and computation (self-attention is $O(n^2)$ in seq_len).
d_model and d_ff: could be reduced to make the model smaller (say 256 and 1024) at the cost of some accuracy.
num_epochs: if using the full dataset, 5 epochs might be overkill (with ~50M examples, even 1 epoch is a lot of training); large datasets often need fewer passes. With only 1% of the data, more epochs help to reuse it. So tune this, ideally training until validation metrics stop improving rather than for an arbitrary number of epochs.
Gradient accumulation: not in this code, but if you want effectively larger batches without more memory, you can accumulate gradients over multiple smaller batches before calling optimizer.step() (see the sketch after this list).
Mixed precision: training can be sped up with float16 (not shown here, but worth considering on modern GPUs).
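As promised above, here is a hedged sketch of what gradient accumulation could look like if you bolted it onto the training loop yourself. accum_steps is a hypothetical new setting, and none of this is in the repo as-is; model, loss_fn, optimizer, lr_scheduler, device, and vocab_size are assumed to be defined as before:

```python
accum_steps = 4  # hypothetical: effective batch size becomes batch_size * accum_steps

optimizer.zero_grad(set_to_none=True)
for i, batch in enumerate(batch_iterator):
    encoder_input = batch['encoder_input'].to(device)
    decoder_input = batch['decoder_input'].to(device)
    encoder_mask = batch['encoder_mask'].to(device)
    decoder_mask = batch['decoder_mask'].to(device)
    label = batch['label'].to(device)

    proj_output = model.project(
        model.decode(model.encode(encoder_input, encoder_mask),
                     encoder_mask, decoder_input, decoder_mask))
    loss = loss_fn(proj_output.view(-1, vocab_size), label.view(-1))

    (loss / accum_steps).backward()   # scale so the accumulated gradient matches one big batch
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        lr_scheduler.step()           # one scheduler step per effective (large) batch
```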
The training as implemented is pretty standard for a Transformer. It even integrates best practices like warmup and label smoothing. With this training loop, if we feed it the data and let it run, it will gradually learn to translate from all supported languages into English. But how do we evaluate or use the model after training? That’s where testing and inference come in.
5. Testing and Inference
Once a model is trained (or while it’s being trained), we want to test it – that is, feed in some inputs and see the translated output. This can be done in code or via provided scripts/notebooks. We also need to understand how to use the saved weights (checkpoints) for inference without retraining everything.
5.1 Using Saved Weights (What are “weights” anyway?)
Model weights refer to the learned parameters of the model (the numbers in all those matrices and embeddings) after training. In PyTorch, model.state_dict()
will give you a dictionary of all parameter tensors. The training code saved a checkpoint file (like .../tmodel_multi_04.pt
) containing the model’s state_dict and other info.
To use these weights for inference:
We need the model architecture code (which we have in model.py).
We load the checkpoint, initialize a model with the same architecture, then load the state dict into it. This is exactly what the training code does when resuming: model.load_state_dict(state['model_state_dict']).
So for inference, one can:
```python
config = get_config()
tokenizer = Tokenizer.from_file(config['tokenizer_file'])
vocab_size = tokenizer.get_vocab_size()  # shared vocabulary used for both source and target
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # pick an available device
model = build_transformer(vocab_size, vocab_size, config["seq_len"], config["seq_len"], d_model=config["d_model"]).to(device)
state = torch.load(latest_weights_file_path(config), map_location=device)
model.load_state_dict(state['model_state_dict'])
model.eval()
```
This is essentially what the translate.py
script in the repo does. It hides some details but we saw relevant parts:
It loads config and sets src_lang and tgt_lang (defaults to first if not given).
Loads the Tokenizer from file.
Builds the model with build_transformer(...) using a shared vocab (the same vocab_size for source and target).
Loads the latest weights file via latest_weights_file_path and torch.load.
Loads the state dict into the model.
Puts model in eval mode and on the right device.
Thus, “weights” are just those learned values, and loading them into the model restores it to the trained state. Without doing this, the model would be at random initialization (and output gibberish).
5.2 Greedy Decoding vs. Teacher Forcing (during training)
During training, we used teacher forcing: we gave the decoder the ground truth previous token to predict the next token at each time step (via decoder_input
which was the true target shifted right). At test time, we don’t have the ground truth output— we must generate it. There are different strategies for generation (greedy, beam search, etc.).
The simplest is greedy decoding: at each step, pick the most probable next token and append it, then feed that as input to generate the following token, until you hit an end-of-sentence token.
The repository provides a translate
function (and a similar greedy_decode
in train.py) that does exactly this:
```python
# Pseudocode for greedy generation:
encoder_output = model.encode(source_ids, source_mask)
decoder_input = [ <2en> ]  # start with the target-language token
while True:
    decoder_mask = causal_mask(decoder_input_length)
    dec_out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
    next_token_logits = model.project(dec_out)[:, -1, :]  # last time step logits
    next_token_id = argmax(next_token_logits)             # pick the highest-probability token
    append next_token_id to decoder_input
    if next_token_id == EOS or length == max_len: break
```
In the actual code:
They prepare source by encoding the input sentence and adding [SOS] ... [EOS] etc., then create source_mask.
encoder_output = model.encode(source, source_mask).
Initialize decoder_input with the <2en> token id.
Loop:
Create a decoder_mask from torch.triu(torch.ones((1, L, L)), diagonal=1), where L is the current decoder_input length. Note that triu with diagonal=1 puts 1s strictly above the diagonal; the causal mask is its complement (1s on and below the diagonal, 0s above), so each position can only attend to itself and earlier tokens.
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask).
prob = model.project(out[:, -1]) gives the logits for the last time step (since out has shape [1, L, d_model], projecting the last index gives [1, vocab_size]).
_, next_word = torch.max(prob, dim=1) picks the argmax token id.
Append that to decoder_input (which is a tensor; they create a new tensor and cat).
If next_word is EOS, or the length hits seq_len, break (to avoid an infinite loop).
Return the decoded sequence as text via tokenizer.decode(...).
The greedy_decode
in train.py is similar but structured slightly differently (and it logs debug info). The result is the same: you get a translated sequence of token IDs, which is then converted back to text.
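To make the loop fully concrete, here is a minimal PyTorch-style sketch of greedy decoding under the assumptions described above (a model exposing encode/decode/project, a start token id such as the <2en> token, and an EOS id). It is my own illustration, not the repo's exact greedy_decode or translate code:

```python
import torch

def greedy_decode_sketch(model, source, source_mask, start_id, eos_id, max_len, device):
    # Encode the source once and reuse the encoder output at every decoding step.
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.tensor([[start_id]], dtype=torch.long, device=device)

    while decoder_input.size(1) < max_len:
        L = decoder_input.size(1)
        # Causal mask: True on and below the diagonal, False above it.
        decoder_mask = (torch.triu(torch.ones((1, L, L), device=device), diagonal=1) == 0)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        prob = model.project(out[:, -1])        # logits for the last position: [1, vocab_size]
        _, next_word = torch.max(prob, dim=1)   # greedy pick

        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)  # append token

        if next_word.item() == eos_id:
            break

    return decoder_input.squeeze(0)  # token ids, to be passed to tokenizer.decode(...)
```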
Using the Test folder / notebooks:
The repository has a tests/ directory. For example, tests/test_train_loop.py runs a quick training on a tiny dataset fraction just to ensure the training code works.
In the absence of a dedicated inference notebook, the translate.py script is the inference entry point. One could integrate it into a notebook or a simple CLI tool.
So to test the model with an example, you might do:
```python
from translate import translate

sentence = "My name is montek"  # This is a test sentence
result = translate(sentence, src_lang="hi")
print(result)
```
5.3 Using Already-Trained Weights for Quick Testing
Ensure the weights_multi folder has a checkpoint (if training completed).
Call translate("Some sentence", src_lang="xx") with the appropriate src_lang code.
6. Streamlit App: Building a Mini Google Translate UI
To make our model accessible and interactive, the repository includes a Streamlit app (translate_app.py
). Streamlit is a Python framework for creating web apps with minimal code, often used for data science demos. The app essentially creates a simple web interface where a user can input text in English and get translations in any of the supported languages, mimicking a Google Translate-like experience.
Let’s walk through what this app does:
It imports Streamlit and the translate function we discussed.
It defines a list of (code, name) pairs for the 11 Indic languages and creates mappings (code -> name and name -> code).
Sets the page title and layout, and displays a title for the app. The title printed is "Ai4bharart Samanater\n Bridging Language Barriers with AI" (there’s a small typo in "Samanantar" there, but that’s fine).
Creates two columns: left for English input, right for translated output.
In the left column:
It shows a subheader "Enter English text".
A text area for input text (keyed as "input_text").
In the right column:
It shows "Translation" subheader.
A selectbox with all target language names, defaulting to Hindi.
An empty string translated_text is prepared.
If the user has entered something (if english_text.strip():), it tries to run the translation:

```python
device = torch.device("cuda" if torch.cuda.is_available()
                      else "mps" if torch.backends.mps.is_available()
                      else "cpu")
translated_text = translate(english_text,
                            src_lang=lang_name_to_code[selected_lang_name],
                            device=device)
```

If translate raises an exception, it catches it and sets translated_text to an error message.
Finally, it displays a text area with the translated_text (marked as output and not editable).
So, the UI is essentially:
Left: big box to input English text.
Right: dropdown to choose one of the 11 languages, and a big box showing the translation.
When you run this app (with streamlit run translate_app.py
), it will load the model on startup (the first call to translate will trigger model loading because inside translate
it loads weights). So the first translation might take a moment as the model and tokenizer load into memory. After that, each time you modify the English text or change the language, it will call translate
again.
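For readers who want to reproduce the UI without digging into the repo, here is a stripped-down sketch of such an app. The language list is truncated and the widget labels and layout details are my own; the real translate_app.py covers all 11 languages and has more polish:

```python
import streamlit as st
import torch
from translate import translate  # the inference function described in section 5

# (code, name) pairs -- truncated here; the real app lists all 11 Indic languages
LANGUAGES = [("hi", "Hindi"), ("ta", "Tamil"), ("bn", "Bengali")]
name_to_code = {name: code for code, name in LANGUAGES}

st.set_page_config(page_title="Transformer Translator", layout="wide")
st.title("Bridging Language Barriers with AI")

left, right = st.columns(2)
with left:
    st.subheader("Enter English text")
    english_text = st.text_area("Input", key="input_text")
with right:
    st.subheader("Translation")
    selected = st.selectbox("Target language", list(name_to_code), index=0)

    translated_text = ""
    if english_text.strip():
        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            translated_text = translate(english_text, src_lang=name_to_code[selected], device=device)
        except Exception as e:
            translated_text = f"Translation failed: {e}"
    st.text_area("Output", value=translated_text, disabled=True)
```

Run it with streamlit run on the file, just as with the repo's own app.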
The Streamlit app provides an easy way for non-technical users (or for a demo) to try the translator. You type an English sentence, select “Tamil”, and see the Tamil translation appear.
7. Historical Context: Pre-Transformer Translation and Trade-offs
To fully appreciate Transformers (and also to question whether they’re always the right choice), let’s briefly look at what came before and how Transformers compare.
7.1 Pre-Transformer: RNNs, LSTMs, and Seq2Seq Models
The idea of sequence-to-sequence (seq2seq) learning for translation was popularized around 2014. The classic architecture (Sutskever et al. 2014) used two RNNs: an encoder RNN to read the source sentence into a context vector, and a decoder RNN to generate the target sentence from that vector. This initial approach struggled with long sentences because the entire source had to be compressed into a single vector at the encoder’s last state.
Attention mechanism (Bahdanau et al. 2015) was a breakthrough that alleviated this by allowing the decoder to “attend” to the encoder’s output at each time step, rather than just relying on one vector. This gave rise to the RNN Encoder-Attention-Decoder architecture (often just called “seq2seq with attention”). For a while (2015-2017), this was the state-of-the-art in machine translation:
Typically using LSTM or GRU networks (which are types of RNNs that handle long-term dependencies better via gating).
E.g., a 2-layer Bi-LSTM encoder, and a 2-layer LSTM decoder with attention was common for many translation systems.
Google Translate (early versions): Initially (pre-2016), Google Translate was not neural at all – it used a phrase-based statistical machine translation (SMT) system. In 2016, Google announced they switched to a neural approach called Google Neural Machine Translation (GNMT). GNMT’s architecture (Wu et al. 2016) was a massive LSTM-based seq2seq model: 8 encoder LSTM layers, 8 decoder LSTM layers, with residual connections between layers, and an attention mechanism. It achieved large improvements in translation quality over the phrase-based system.
So up until 2017, the best translation models were using recurrent networks + attention. These models have some characteristics:
Sequential processing: The RNN reads one word at a time (though bidirectional encoders read both directions separately). This makes it difficult to parallelize on hardware, as discussed – if you have a long sentence, you can’t easily shorten that wall-clock time by using more GPUs, you still have to step through word by word.
Memory issues: RNNs have to maintain a hidden state that carries information through the sequence. Very long sequences can still be problematic, though attention mitigated the bottleneck by letting the model lookup specific parts of the source via attention instead of storing everything in the hidden state.
Fewer parameters for same layer size: An LSTM layer with hidden size 512 will have far fewer parameters than a self-attention layer with 8 heads of 64 each plus all those matrices. So RNN models often had fewer parameters than an equivalent Transformer. (However, they might need more layers to get similar expressiveness. And Transformers often scale parameters up to use the extra capacity beneficially.)
Training data needs: Both approaches benefit from more data, but Transformers, with more parameters and no recurrence, often needed even more data to generalize well initially (plus techniques like regularization, dropout, etc., which both needed).
Convergence speed: Transformers typically converge faster per epoch because of parallelism and perhaps easier access to long-range info. But each epoch might process more data due to parallelism, so it’s somewhat balanced.
Other pre-Transformer methods:
There were also convolutional seq2seq models (e.g., Facebook’s ConvS2S in 2017) that used CNNs over sequences + attention, offering parallelism over timesteps but limited context per layer.
Various tricks like Byte Pair Encoding (BPE) for subword tokenization (still used with Transformers) were developed in the RNN era to address open vocabulary.
7.2 Why Transformers Replaced RNNs in Translation
When Transformers came, they demonstrated superior performance on translation tasks (e.g., English-German, English-French benchmarks in the paper) at a fraction of the training time (because you can train on more GPUs effectively). Over time, as hardware advanced, Transformers scaled to unprecedented sizes (see models like GPT-3 with 175 billion parameters – impossible with an RNN-based approach).
Quality: The self-attention mechanism is very expressive. It can learn alignments like the earlier attention mechanisms, but also more complex relationships. Multi-head attention means the model can capture multiple different alignment patterns or linguistic correlations simultaneously. This often leads to better fluency and adequacy in translations – especially for longer sentences where an RNN might forget or mix context, a Transformer can directly attend to the relevant part even if it’s far away.
Computational cost: Transformers are not strictly “cheaper” – in fact, self-attention is $O(n^2)$ in sequence length for computation and memory (because of the attention matrix of size seq_len x seq_len). RNNs are $O(n)$ in sequence length (linear). So for very long sequences, vanilla Transformers can be slow or memory-heavy. In translation, typical sentences aren’t crazy long (a few dozen words on average), so it’s usually fine. But if you tried to use a Transformer on say a book-length input (thousands of tokens), you’d hit efficiency issues – and indeed research into efficient/longform Transformers is ongoing.
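To put rough numbers on that, using the seq_len of 350 from our config as an illustration (and ignoring batch size and the $d_{\text{model}}$ factor, since the point is only quadratic vs. linear growth):

$$
\underbrace{n^2}_{\text{attention entries per head, per layer}} = 350^2 = 122{,}500
\qquad \text{vs.} \qquad
\underbrace{n}_{\text{sequential RNN steps}} = 350 .
$$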
In practice, for sentences up to a few hundred tokens, the parallelism advantage outweighs the quadratic cost, given modern GPUs can handle matrix multiplies very well. So Transformers train faster and yield better quality for typical translation tasks. RNNs might still be useful for streaming scenarios or very low-resource environments, but most benchmarks and industry systems have moved to Transformers.
Is it worth using Transformers for similar tasks? For translation, yes – almost unequivocally; transformers are the state of the art. For other sequence tasks like speech recognition or certain low-latency applications, one might still consider RNNs or hybrids. But in NLP, transformers dominate.
One caveat: transformer models are larger and require more memory. If one had to deploy a translation model on a mobile device, a full transformer might be too heavy, and a smaller LSTM model or a distilled transformer might be chosen. There’s also the concept of knowledge distillation – one can train a big transformer then distill it into a smaller model (even an RNN) if needed.
For the languages in question (Indic languages), having a unified multilingual transformer helps with efficiency (one model for all languages) and can do zero-shot or transfer learning (maybe the model could translate a language pair it never explicitly saw via bridging through English, etc.). RNN-based multilingual models also existed, but again, the transformer’s capacity is beneficial for juggling many languages.
Early Google Translate vs Transformers: Google’s LSTM-based system (GNMT) was very advanced for its time, using attention and residual connections, etc. When Google switched to Transformers (around 2018 in production, as they reported), they saw inference speed improvements and quality gains. The trade-off was that they had to retrain models and the models might have been bigger, but Google has the resources.
For an individual or researcher, the answer to “Is it worth using Transformers for similar tasks?” is usually yes if you have the compute – they often give better accuracy for the time spent. But if you only have a CPU or a mobile device, a small GRU model might be “worth it” in that constrained scenario.
7.3 Final Thoughts on Historical Evolution
The sequence modeling field has largely coalesced around the transformer architecture, even beyond text (see Vision Transformers in computer vision, etc.). The ideas from RNN seq2seq and attention laid the groundwork, but the transformer took it to the next level.
From SMT to LSTMs to Transformers, each leap brought better translation quality and easier scaling:
SMT (phrase-based): fast, but built on hand-engineered features and components that required huge human effort, and it couldn’t capture context well.
LSTM seq2seq: learned from data directly, handled arbitrary sentences, improved with attention.
Transformers: further improved learning by removing sequential bottlenecks and enabling training on much bigger data.
In the context of our project: using a Transformer for English-Indic translation is definitely justified given the complexity of languages and nuances (like handling long-distance agreement or reordering, which attention can capture). The quality improvement is generally worth the computational cost, especially if targeting high-quality translation outputs. The Samanantar dataset’s scale also practically demands a model that can effectively leverage lots of data – something Transformers excel at due to their scalability.
Conclusion
In this post, we covered a lot of ground – from the theoretical underpinnings of the Transformer model in the “Attention Is All You Need” paper, to a hands-on understanding of each component via a PyTorch implementation. We also discussed the multilingual Samanantar dataset that enables translation between English and 11 Indic languages, and how to train a Transformer on such data (including setting hyperparameters and using training tricks like warmup scheduling and label smoothing).
We explored how to actually generate translations with a trained model, emphasizing how to load model weights and perform inference step-by-step. Additionally, we saw how easy it is to wrap the model in a Streamlit app to create a mini Google Translate-like interface for demonstrations. Finally, we put everything in context by reviewing how translation models evolved from recurrent networks to Transformers and what trade-offs come with these choices.
Transformers continue to be an exciting area, with ongoing research improving efficiency (for long sequences), adapting them to low-resource settings, and extending their capabilities. But the fundamentals remain as described here. If you’ve made it this far, you should have a solid understanding of how Transformers work and how to implement and use them for translation tasks.
Happy translating with Transformers!
References:
Vaswani et al., Attention Is All You Need, NeurIPS 2017. (Introduced the Transformer model).
Artificial intelligence sheds light on how the brain processes language
AI Chatbots Work by Predicting the Next Word. So Do Our Brains. Is There a Connection?
What ChatGPT understands: Large language models and the neuroscience of meaning