End To End Paper Implementation “Attention Is All You Need”

Attention is All You Need: Complete Implementation

In this article, I present an end-to-end implementation of the paper “Attention is All You Need”, along with selected quotes from the paper.

This article focuses only on implementation. For a more explanatory and conceptual guide, I recommend the following YouTube video: https://www.youtube.com/watch?v=KJtZARuO3JY

This content originally appeared on DEV Community and was authored by Ramazan Turan

Detailed Implementation

Components

  • Encoder
  • Decoder
  • Attention Mechanisms
    • Self-Attention (Encoder)
    • Masked Self-Attention (Decoder)
    • Cross-Attention (Encoder-Decoder Attention)
  • Position-wise Feed-Forward Networks
  • Layer Normalization
  • Positional Encoding
  • Embedding and Output Layers
    • Input Embedding
    • Output Linear and Softmax

Architecture

Encoder

"The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network. We employ a residual connection around each of the two sub-layers, followed by layer normalization. That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension dmodel = 512."

Positional Encoding

class PositionalEncoding(nn.Module):
    """
    Adds positional encoding to the token embeddings for the Transformer model
    Paper Reference: Section 3.5 "Positional Encoding"
    Implements the sine/cosine formulas given in that section.
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Positional encoding calculation
        # Paper Reference: Section 3.5 (sinusoidal formulas)
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # Sine for even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Cosine for odd indices
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]

        # Register as a persistent buffer (not a model parameter)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
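The `div_term` line computes 1 / 10000^(2i/d_model) in log space as exp(2i * (-ln 10000 / d_model)), which is mathematically identical. A quick stdlib-only check, using hypothetical helpers `pe_direct` (the paper's formula) and `pe_div_term` (the form used in the class), confirms the two agree and that columns 2i and 2i+1 share one frequency:

```python
import math


def pe_direct(pos: int, i: int, d_model: int) -> float:
    """PE straight from the paper: sin for even columns, cos for odd columns."""
    angle = pos / (10000 ** ((i - i % 2) / d_model))
    return math.sin(angle) if i % 2 == 0 else math.cos(angle)


def pe_div_term(pos: int, i: int, d_model: int) -> float:
    """Same value via the exp/log 'div_term' form used in PositionalEncoding."""
    div_term = math.exp((i - i % 2) * (-math.log(10000.0) / d_model))
    angle = pos * div_term
    return math.sin(angle) if i % 2 == 0 else math.cos(angle)


d_model = 8
max_diff = max(
    abs(pe_direct(p, i, d_model) - pe_div_term(p, i, d_model))
    for p in range(50)
    for i in range(d_model)
)
print(f"max difference: {max_diff:.2e}")
```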

Multi-Head Self-Attention


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism
    Paper Reference: Section 3.2.2 "Multi-Head Attention"
    Scaled dot-product attention is Equation 1 in the paper.
    Structure is shown on the right side of Figure 2 (the left side shows scaled dot-product attention).
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_head == 0

        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head

        # Linear projections
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch_size, seq_len_q, d_model]
            key: [batch_size, seq_len_k, d_model]
            value: [batch_size, seq_len_v, d_model]
            mask: [batch_size, seq_len_q, seq_len_k] or [batch_size, 1, seq_len_q, seq_len_k]

        Paper Reference: Sections 3.2.1 and 3.2.2
        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
        MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
            head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
        """
        batch_size = query.size(0)

        # Linear projections and head separation
        # [batch_size, seq_len, d_model] -> [batch_size, n_head, seq_len, d_k]
        q = self.wq(query).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.wk(key).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.wv(value).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention calculation
        # Paper Reference: Section 3.2.1, Equation 1
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Masking (optional)
        if mask is not None:
            # Handle different mask dimensions
            if mask.dim() == 3:  # [batch_size, seq_len_q, seq_len_k]
                mask = mask.unsqueeze(1)  # [batch_size, 1, seq_len_q, seq_len_k]
            elif mask.dim() == 4:  # [batch_size, 1, seq_len_q, seq_len_k]
                pass  # Already correct dimension

            # Expand mask for all heads
            mask = mask.expand(batch_size, self.n_head, -1, -1)
            scores = scores.masked_fill(mask == 0, -1e9)

        # Softmax and Dropout
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Output calculation
        out = torch.matmul(attn, v)  # [batch_size, n_head, seq_len_q, d_k]
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        out = self.fc(out)

        return out
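With the head splitting removed, the core of `forward` is Equation 1 plus the mask. The standalone sketch below (the function name is our own, not a PyTorch API) uses the same convention as the class, where 0 marks a blocked position, and shows the two properties the mask should guarantee: each query's weights still sum to 1, and masked keys receive zero weight because exp(-1e9) underflows to 0 in float32.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
    """softmax(QK^T / sqrt(d_k)) V; positions where mask == 0 are excluded."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Same convention as MultiHeadAttention above: 0 = blocked, 1 = allowed
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
q = torch.randn(1, 3, 4)  # [batch, seq_len_q, d_k]
k = torch.randn(1, 5, 4)  # [batch, seq_len_k, d_k]
v = torch.randn(1, 5, 4)
# Pretend the last two key positions are padding
mask = torch.tensor([[[1, 1, 1, 0, 0]]])  # [batch, 1, seq_len_k], broadcasts over queries
out, weights = scaled_dot_product_attention(q, k, v, mask)
print(out.shape)                # torch.Size([1, 3, 4])
print(weights[..., 3:].max())   # padded keys receive zero weight
```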

Position-wise Feed-Forward

class PositionwiseFeedForward(nn.Module):
    """
    Two-layer Feed-Forward Network
    Paper Reference: Section 3.3 "Position-wise Feed-Forward Networks"
    FFN(x) = max(0, xW₁ + b₁)W₂ + b₂ 
    Described in Equation 2.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        return self.fc2(self.dropout(F.relu(self.fc1(x))))
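"Position-wise" means the same two linear layers are applied to every position independently. A small sketch, assuming PyTorch and dropping dropout so the comparison is deterministic, checks that applying the network to the whole [batch, seq_len, d_model] tensor matches applying it one position at a time, and counts the parameters for the base sizes d_model = 512, d_ff = 2048:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 512, 2048
fc1, fc2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)


def ffn(x):
    # FFN(x) = max(0, xW1 + b1)W2 + b2 (dropout omitted for a deterministic check)
    return fc2(F.relu(fc1(x)))


x = torch.randn(2, 7, d_model)  # [batch, seq_len, d_model]
whole = ffn(x)
# Apply the network to one position at a time and stack the results
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)
print(torch.allclose(whole, per_position, atol=1e-4))  # True

n_params = sum(p.numel() for p in (*fc1.parameters(), *fc2.parameters()))
print(n_params)  # 2099712 = 2 * 512 * 2048 + 2048 + 512
```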

Encoder Layer

class EncoderLayer(nn.Module):
    """
    Transformer Encoder Layer
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks" - Encoder part
    Structure shown on the left side of Figure 1.
    Each encoder layer has a multi-head self-attention and a feed-forward network.
    Each sublayer has a residual connection and layer normalization.
    """

    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # Quote: "We apply dropout [33] to the output of each sub-layer, before it is 
        # added to the sub-layer input and normalized"
        # Section: 5.4 Regularization

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: [batch_size, seq_len, seq_len]

        Paper Reference: Section 3.1 (residual connection followed by layer normalization)
        For each sublayer: LayerNorm(x + Sublayer(x))
        """
        # First sublayer: Multi-Head Self-Attention
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output)) # Residual connection + LayerNorm

        # Second sublayer: Position-wise Feed-Forward Network
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output)) # Residual connection + LayerNorm

        return x

Encoder

class Encoder(nn.Module):
    """
    Transformer Encoder
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks"
    The encoder consists of N=6 identical encoder layers.
    First applies token embedding and positional encoding.
    """
    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Encoder, self).__init__()
        # Paper Reference: Section 3.4, "Embeddings and Softmax"
        # "We multiply those weights by sqrt(d_model)"
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, src_seq_len]
            mask: [batch_size, src_seq_len, src_seq_len]

        Paper Reference: Section 3.1 "Encoder"
        The encoder consists of N identical encoder layers.
        """
        # Embedding and Positional Encoding
        # Paper Reference: Section 3.4, "We multiply those weights by sqrt(d_model)"
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        return x

Decoder

"The decoder is also composed of a stack of N = 6 identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization. We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i."

Decoder Layer

class DecoderLayer(nn.Module):
    """
    Transformer Decoder Layer
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks" - Decoder part
    Structure shown on the right side of Figure 1.
    Each decoder layer has a masked multi-head self-attention, 
    a multi-head cross-attention, and a feed-forward network.
    Each sublayer has a residual connection and layer normalization.
    """
    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            enc_output: [batch_size, src_seq_len, d_model]
            self_mask: [batch_size, seq_len, seq_len]
            cross_mask: [batch_size, seq_len, src_seq_len]

        Paper Reference: Section 3.1, "Decoder"
        The decoder has masked multi-head attention, multi-head attention, and 
        feed-forward network sublayers.
        For each sublayer: LayerNorm(x + Sublayer(x))
        """
        attn_output = self.self_attn(x, x, x, self_mask)
        x = self.norm1(x + self.dropout1(attn_output))  # Residual connection + LayerNorm

        # Sublayer with Cross-Attention (Decoder attends to encoder output)
        attn_output = self.cross_attn(x, enc_output, enc_output, cross_mask)
        x = self.norm2(x + self.dropout2(attn_output)) # Residual connection + LayerNorm

        # Sublayer with Feed-Forward Network
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout3(ff_output)) # Residual connection + LayerNorm

        return x
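For comparison, PyTorch ships its own `nn.MultiheadAttention`. The sketch below uses it only to illustrate the shapes in the cross-attention sub-layer: queries come from the decoder, keys and values from the encoder output, so the output length follows the target. Note the opposite mask convention: in `key_padding_mask`, `True` means "ignore this key", whereas the article's masks use 0 for blocked positions. `batch_first=True` assumes PyTorch 1.9 or later.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_head = 512, 8
cross_attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

dec_x = torch.randn(2, 7, d_model)     # decoder states (queries), tgt_len = 7
enc_out = torch.randn(2, 11, d_model)  # encoder output (keys/values), src_len = 11
src_padding = torch.zeros(2, 11, dtype=torch.bool)
src_padding[1, 8:] = True  # PyTorch convention: True marks a padded key to ignore

out, weights = cross_attn(dec_x, enc_out, enc_out, key_padding_mask=src_padding)
print(out.shape)      # torch.Size([2, 7, 512]): length follows the query
print(weights.shape)  # torch.Size([2, 7, 11]): one weight per (target, source) pair
```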

Decoder

class Decoder(nn.Module):
    """
    Transformer Decoder
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks"
    The decoder consists of N=6 identical decoder layers.
    Like the encoder, it first applies token embedding and positional encoding.
    The decoder also uses masking for subsequent positions (section 3.2.3).
    """

    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: [batch_size, tgt_seq_len]
            enc_output: [batch_size, src_seq_len, d_model]
            self_mask: [batch_size, tgt_seq_len, tgt_seq_len] or [batch_size, 1, tgt_seq_len, tgt_seq_len]
            cross_mask: [batch_size, tgt_seq_len, src_seq_len] or [batch_size, 1, tgt_seq_len, src_seq_len]
        """
        # Embedding and Positional Encoding
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)

        # Pass through decoder layers
        for layer in self.layers:
            x = layer(x, enc_output, self_mask, cross_mask)

        x = self.norm(x)
        return x

Transformer Model

class Transformer(nn.Module):
    """
    Transformer model (Attention is All You Need)
    Paper Reference: The entire paper, especially Section 3 and Figure 1
    The Transformer consists of an encoder and a decoder.
    The output projection converts decoder output into logits over the target vocabulary (softmax is applied later, e.g. inside the loss).
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_head=8, 
                 d_ff=2048, n_layers=6, dropout=0.1, max_len=5000):
        super(Transformer, self).__init__()
        # Paper Reference: Section 3.1 and Table 3
        # Base model: d_model=512, n_head=8, d_ff=2048, n_layers=6, dropout=0.1
        self.encoder = Encoder(src_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.projection = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """
        Args:
            src: [batch_size, src_seq_len]
            tgt: [batch_size, tgt_seq_len]
            src_mask: [batch_size, src_seq_len, src_seq_len]
            tgt_mask: [batch_size, tgt_seq_len, tgt_seq_len]
            memory_mask: [batch_size, tgt_seq_len, src_seq_len]

        Paper Reference: The entire paper, especially Figure 1
        Transformer model flow: 
        1. Encoder takes input and produces encoder output
        2. Decoder takes encoder output and its own input
        3. Final projection produces logits over the target vocabulary
        """
        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, tgt_mask, memory_mask)
        output = self.projection(dec_output)
        return output
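Section 3.4 of the paper (the same section the embedding-scaling comment cites) also states that the two embedding layers and the pre-softmax linear transformation share one weight matrix; the `Transformer` class above keeps them separate. A minimal sketch of weight tying, assuming PyTorch and illustrative sizes:

```python
import torch
import torch.nn as nn

vocab_size, d_model = 1000, 512  # illustrative sizes, not the paper's

embedding = nn.Embedding(vocab_size, d_model)  # weight: [vocab_size, d_model]
projection = nn.Linear(d_model, vocab_size)    # weight: [vocab_size, d_model]
untied = sum(p.numel() for p in (*embedding.parameters(), *projection.parameters()))

# Tie the weights: the projection now reuses the embedding's Parameter object
projection.weight = embedding.weight

# Count each distinct Parameter once, as nn.Module.parameters() does for a parent module
unique = {id(p): p for p in (*embedding.parameters(), *projection.parameters())}
tied = sum(p.numel() for p in unique.values())
print(untied - tied)  # 512000

logits = projection(torch.randn(2, 3, d_model))  # [batch, seq_len, vocab_size]
print(logits.shape)  # torch.Size([2, 3, 1000])
```

Inside the article's `Transformer.__init__`, the analogous (hypothetical) line would be `self.projection.weight = self.decoder.embedding.weight`; also tying the encoder embedding only makes sense when source and target share one vocabulary.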

Full Code


import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """
    Adds positional encoding to the token embeddings for the Transformer model
    Paper Reference: Section 3.5 "Positional Encoding"
    Implements the sine/cosine formulas given in that section.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Positional encoding calculation
        # Paper Reference: Section 3.5 (sinusoidal formulas)
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices use sine
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd indices use cosine
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]

        # Register as persistent buffer (not a model parameter)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism
    Paper Reference: Section 3.2.2 "Multi-Head Attention"
    Scaled dot-product attention is Equation 1 in the paper.
    Structure is shown on the right side of Figure 2 (the left side shows scaled dot-product attention).
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_head == 0

        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head

        # Linear projections
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch_size, seq_len_q, d_model]
            key: [batch_size, seq_len_k, d_model]
            value: [batch_size, seq_len_v, d_model]
            mask: [batch_size, seq_len_q, seq_len_k] or [batch_size, 1, seq_len_q, seq_len_k]

        Paper Reference: Sections 3.2.1 and 3.2.2
        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
        MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
            head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
        """
        batch_size = query.size(0)

        # Linear projections and head separation
        # [batch_size, seq_len, d_model] -> [batch_size, n_head, seq_len, d_k]
        q = self.wq(query).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.wk(key).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.wv(value).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention calculation
        # Paper Reference: Section 3.2.1, Equation 1
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Masking (optional)
        if mask is not None:
            # Handle different mask dimensions
            if mask.dim() == 3:  # [batch_size, seq_len_q, seq_len_k]
                mask = mask.unsqueeze(1)  # [batch_size, 1, seq_len_q, seq_len_k]
            elif mask.dim() == 4:  # [batch_size, 1, seq_len_q, seq_len_k]
                pass  # Already correct dimension

            # Expand mask for all heads
            mask = mask.expand(batch_size, self.n_head, -1, -1)
            scores = scores.masked_fill(mask == 0, -1e9)

        # Softmax and Dropout
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Output calculation
        out = torch.matmul(attn, v)  # [batch_size, n_head, seq_len_q, d_k]
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        out = self.fc(out)

        return out

class PositionwiseFeedForward(nn.Module):
    """
    Two-layer Feed-Forward Network
    Paper Reference: Section 3.3 "Position-wise Feed-Forward Networks"
    FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
    Described in Equation 2.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        return self.fc2(self.dropout(F.relu(self.fc1(x))))

class EncoderLayer(nn.Module):
    """
    Transformer Encoder Layer
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks" - Encoder part
    Structure shown on the left side of Figure 1.
    Each encoder layer has a multi-head self-attention and a feed-forward network.
    Each sublayer has a residual connection and layer normalization.
    """

    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Quote: "We apply dropout [33] to the output of each sub-layer, before it is 
        # added to the sub-layer input and normalized"
        # Section: 5.4 Regularization
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: [batch_size, seq_len, seq_len]

        Paper Reference: Section 3.1 (residual connection followed by layer normalization)
        For each sublayer: LayerNorm(x + Sublayer(x))
        """
        # First sublayer: Multi-Head Self-Attention
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output)) # Residual connection + LayerNorm

        # Second sublayer: Position-wise Feed-Forward Network
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output)) # Residual connection + LayerNorm

        return x

class DecoderLayer(nn.Module):
    """
    Transformer Decoder Layer
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks" - Decoder part
    Structure shown on the right side of Figure 1.
    Each decoder layer has a masked multi-head self-attention,
    a multi-head cross-attention, and a feed-forward network.
    Each sublayer has a residual connection and layer normalization.
    """

    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            enc_output: [batch_size, src_seq_len, d_model]
            self_mask: [batch_size, seq_len, seq_len]
            cross_mask: [batch_size, seq_len, src_seq_len]

        Paper Reference: Section 3.1, "Decoder"
        The decoder has masked multi-head attention, multi-head attention, and 
        feed-forward network sublayers.
        For each sublayer: LayerNorm(x + Sublayer(x))
        """
        attn_output = self.self_attn(x, x, x, self_mask)
        x = self.norm1(x + self.dropout1(attn_output))  # Residual connection + LayerNorm

        # Sublayer with Cross-Attention (Decoder attends to encoder output)
        attn_output = self.cross_attn(x, enc_output, enc_output, cross_mask)
        x = self.norm2(x + self.dropout2(attn_output))  # Residual connection + LayerNorm

        # Sublayer with Feed-Forward Network
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout3(ff_output)) # Residual connection + LayerNorm

        return x

class Encoder(nn.Module):
    """
    Transformer Encoder
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks"
    The encoder consists of N=6 identical encoder layers.
    First applies token embedding and positional encoding.
    """

    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Encoder, self).__init__()
        # Paper Reference: Section 3.4, "Embeddings and Softmax"
        # "We multiply those weights by sqrt(d_model)"
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, src_seq_len]
            mask: [batch_size, src_seq_len, src_seq_len] or [batch_size, 1, src_seq_len, src_seq_len]

        Paper Reference: Section 3.1 "Encoder"
        The encoder consists of N identical encoder layers.
        """
        # Embedding and Positional Encoding
        # Paper Reference: Section 3.4, "We multiply those weights by sqrt(d_model)"
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)

        # Pass through encoder layers
        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        return x

class Decoder(nn.Module):
    """
    Transformer Decoder
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks"
    The decoder consists of N=6 identical decoder layers.
    Like the encoder, it first applies token embedding and positional encoding.
    The decoder also uses masking for subsequent positions (section 3.2.3).
    """

    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: [batch_size, tgt_seq_len]
            enc_output: [batch_size, src_seq_len, d_model]
            self_mask: [batch_size, tgt_seq_len, tgt_seq_len] or [batch_size, 1, tgt_seq_len, tgt_seq_len]
            cross_mask: [batch_size, tgt_seq_len, src_seq_len] or [batch_size, 1, tgt_seq_len, src_seq_len]
        """
        # Embedding and Positional Encoding
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)

        # Pass through decoder layers
        for layer in self.layers:
            x = layer(x, enc_output, self_mask, cross_mask)

        x = self.norm(x)
        return x

class Transformer(nn.Module):
    """
    Transformer model (Attention is All You Need)
    Paper Reference: The entire paper, especially Section 3 and Figure 1
    The Transformer consists of an encoder and a decoder.
    The output projection converts decoder output into logits over the target vocabulary (softmax is applied later, e.g. inside the loss).
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_head=8,
                 d_ff=2048, n_layers=6, dropout=0.1, max_len=5000):
        super(Transformer, self).__init__()
        # Paper Reference: Section 3.1 and Table 3
        # Base model: d_model=512, n_head=8, d_ff=2048, n_layers=6, dropout=0.1
        self.encoder = Encoder(src_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.projection = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """
        Args:
            src: [batch_size, src_seq_len]
            tgt: [batch_size, tgt_seq_len]
            src_mask: [batch_size, src_seq_len, src_seq_len]
            tgt_mask: [batch_size, tgt_seq_len, tgt_seq_len]
            memory_mask: [batch_size, tgt_seq_len, src_seq_len]

        Paper Reference: The entire paper, especially Figure 1
        Transformer model flow:
        1. Encoder takes input and produces encoder output
        2. Decoder takes encoder output and its own input
        3. Final projection produces logits over the target vocabulary
        """
        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, tgt_mask, memory_mask)
        output = self.projection(dec_output)
        return output

def create_masks(src, tgt, pad_idx):
    """
    Creates padding and subsequent masks
    Paper Reference: Section 3.2.3 "Applications of Attention in our Model"
    Masking is applied in the decoder to prevent seeing future positions.
    Masking is also applied for padding tokens.
    """
    # Encoder masking (padding mask)
    # src_mask: [batch_size, src_len] -> [batch_size, 1, src_len, src_len]
    src_pad_mask = (src != pad_idx).unsqueeze(1).unsqueeze(1)  # [B, 1, 1, src_len]
    src_len = src.size(1)
    src_mask = src_pad_mask.expand(-1, -1, src_len, -1)  # [B, 1, src_len, src_len]

    # Decoder self-attention masking (padding mask + subsequent mask)
    tgt_len = tgt.size(1)
    tgt_pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(1)  # [B, 1, 1, tgt_len]
    tgt_pad_mask = tgt_pad_mask.expand(-1, -1, tgt_len, -1)  # [B, 1, tgt_len, tgt_len]

    # Subsequent mask (lower triangular)
    tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
    tgt_sub_mask = tgt_sub_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, tgt_len, tgt_len]

    tgt_mask = tgt_pad_mask & tgt_sub_mask

    # Cross attention masking (encoder-decoder attention)
    # memory_mask: [batch_size, 1, tgt_len, src_len]
    src_pad_mask_for_cross = (src != pad_idx).unsqueeze(1).unsqueeze(1)  # [B, 1, 1, src_len]
    memory_mask = src_pad_mask_for_cross.expand(-1, -1, tgt_len, -1)  # [B, 1, tgt_len, src_len]

    return src_mask, tgt_mask, memory_mask

# Example of model usage
def example_usage():
    """
    Example usage with proper error handling
    """
    try:
        # Parameters
        # Paper Reference: Section 3.1 and Table 3
        # Base model: d_model=512, n_head=8, d_ff=2048, n_layers=6, dropout=0.1
        src_vocab_size = 10000
        tgt_vocab_size = 10000
        d_model = 512
        n_head = 8
        d_ff = 2048
        n_layers = 6
        dropout = 0.1
        pad_idx = 0

        # Create model
        transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, n_head, d_ff, n_layers, dropout)

        # Print model parameters
        total_params = sum(p.numel() for p in transformer.parameters())
        print(f"Total parameters: {total_params:,}")

        # Example input data
        src = torch.randint(1, src_vocab_size, (64, 10))  # [batch_size=64, src_seq_len=10]
        tgt = torch.randint(1, tgt_vocab_size, (64, 20))  # [batch_size=64, tgt_seq_len=20]

        # Create masks
        src_mask, tgt_mask, memory_mask = create_masks(src, tgt, pad_idx)

        print(f"Source mask shape: {src_mask.shape}")
        print(f"Target mask shape: {tgt_mask.shape}")
        print(f"Memory mask shape: {memory_mask.shape}")

        # Forward pass
        output = transformer(src, tgt, src_mask, tgt_mask, memory_mask)
        print(f"Output shape: {output.shape}")  # [64, 20, tgt_vocab_size]
        print("Forward pass successful!")

        return output

    except Exception as e:
        print(f"Error occurred: {str(e)}")
        return None

if __name__ == "__main__":
    example_usage()
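As a sanity check on the `Total parameters` line printed by `example_usage`, the count can be reproduced by hand. The arithmetic below assumes exactly the hyperparameters in that function (10,000-token vocabularies, base model sizes, untied embeddings) and counts each `nn.Linear` as weight plus bias and each `nn.LayerNorm` as gain plus bias; the positional-encoding table is a buffer, so it adds nothing:

```python
def linear(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight + bias


def layer_norm(d: int) -> int:
    return 2 * d  # gain + bias


d_model, d_ff, n_layers = 512, 2048, 6
src_vocab, tgt_vocab = 10000, 10000

mha = 4 * linear(d_model, d_model)  # wq, wk, wv, fc
ffn = linear(d_model, d_ff) + linear(d_ff, d_model)
enc_layer = mha + ffn + 2 * layer_norm(d_model)      # self-attn, FFN, 2 norms
dec_layer = 2 * mha + ffn + 3 * layer_norm(d_model)  # self-attn, cross-attn, FFN, 3 norms

# Embedding + stacked layers + the extra final LayerNorm in Encoder/Decoder
encoder = src_vocab * d_model + n_layers * enc_layer + layer_norm(d_model)
decoder = tgt_vocab * d_model + n_layers * dec_layer + layer_norm(d_model)
projection = linear(d_model, tgt_vocab)

total = encoder + decoder + projection
print(f"Total parameters: {total:,}")  # Total parameters: 59,510,544
```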




Ramazan Turan | Sciencx (2025-07-18T20:31:55+00:00) End To End Paper Implementation “Attention Is All You Need”. Retrieved from https://www.scien.cx/2025/07/18/end-to-end-paper-implementation-attention-is-all-you-need/
