Self-Attention to Multi-Head Attention From Scratch Explained

Scaled Dot-product Attention Mechanism

In the previous notebook, titled Simple Self-Attention Without Trainable Weights, we implemented a simplified attention mechanism without any trainable weights.

In this notebook, we will take it a step further by implementing self-attention with trainable weights.

Let’s see how it works!

Computing the attention weights step by step

To compute the attention weights step by step, let’s first introduce the concepts clearly, and then work through an example.

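We have three trainable weight matrices: Wq, Wk, and Wv. These matrices project the embedded input tokens into query, key, and value vectors, respectively. As a running example, we use the six-token sentence “Your journey starts with one step”, where each token is represented by a 3-dimensional embedding vector: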

import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
inputs.shape
torch.Size([6, 3])

Query, Key and Value Vectors

[Figure: Query, key, and value vectors]

Credit: Build a Large Language Model (From Scratch)

Let’s define a few variables:

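x_2 = inputs[1]         # second input token ("journey")
d_in = inputs.shape[1]  # input embedding size (3)
d_out = 2               # output embedding size for queries, keys, and values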

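Now, we initialize the three weight matrices Wq, Wk, and Wv. Setting requires_grad=False only keeps the printed outputs uncluttered in this walkthrough. To train these matrices, requires_grad would need to be True.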

torch.manual_seed(42)
w_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_query.shape
torch.Size([3, 2])

Next, we compute the query, key, and value vectors for the second input token, x_2:

query_2 = x_2 @ w_query
key_2 = x_2 @ w_key
value_2 = x_2 @ w_value
query_2
tensor([1.0760, 1.7344])

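Note the difference between the two kinds of “weights” here. The weight parameters in Wq, Wk, and Wv are learned during training and define the network’s connections; they stay the same from one input to the next. Attention weights, by contrast, are computed on the fly from the current input and determine how much each token contributes to a context vector.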
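To make the distinction concrete, here is a small sketch. It uses randomly initialized stand-in matrices (not trained weights), two hypothetical random 4-token inputs, and a helper named attention_weights_for that is introduced only for illustration. The same W_query and W_key produce different attention weights for different inputs.

```python
import torch

torch.manual_seed(0)
d_in, d_out = 3, 2

# Fixed projection matrices (randomly initialized stand-ins, not trained)
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)

def attention_weights_for(x):
    """Hypothetical helper: scaled dot-product attention weights for x of shape (num_tokens, d_in)."""
    queries = x @ W_query
    keys = x @ W_key
    scores = queries @ keys.T
    return torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)

x_a = torch.rand(4, d_in)  # one hypothetical 4-token input
x_b = torch.rand(4, d_in)  # another hypothetical 4-token input

w_a = attention_weights_for(x_a)
w_b = attention_weights_for(x_b)

# Same parameters, different inputs -> different attention weights
print(torch.allclose(w_a, w_b))  # False
```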

Next, we obtain the keys and values for all six input tokens with a single matrix multiplication:

keys = inputs @ w_key
values = inputs @ w_value
keys.shape, values.shape
(torch.Size([6, 2]), torch.Size([6, 2]))

Attention Score

[Figure: Computing attention scores]

Credit: Build a Large Language Model (From Scratch)

Next, we compute the attention score ω22, which is the dot product between the query of the second token and its own key:

key_2 = keys[1]  # key of the second token (same as x_2 @ w_key)
attention_score_22 = query_2.dot(key_2)
attention_score_22
tensor(3.3338)

We can generalize this computation to get the attention scores of query 2 against all keys with a single matrix multiplication:

attention_score_2 = query_2 @ keys.T
attention_score_2
tensor([2.7084, 3.3338, 3.3013, 1.7563, 1.7869, 2.1966])
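Going one step further, we can project every token into a query and compute the scores for all six queries at once. The sketch below recreates w_query and w_key by re-seeding with 42 and drawing them in the same order as above (as plain tensors rather than nn.Parameter). Row 1 of the resulting 6×6 matrix is the attention_score_2 vector computed above.

```python
import torch

torch.manual_seed(42)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
d_in, d_out = inputs.shape[1], 2
w_query = torch.rand(d_in, d_out)  # same draws as the first two matrices above
w_key = torch.rand(d_in, d_out)

queries = inputs @ w_query  # (6, 2): one query per token
keys = inputs @ w_key       # (6, 2): one key per token

# Entry [i, j] is the dot product of query i with key j
all_scores = queries @ keys.T  # (6, 6)

# Row 1 (the second token) matches computing query_2 @ keys.T on its own
print(torch.allclose(all_scores[1], queries[1] @ keys.T))  # True
print(all_scores.shape)                                    # torch.Size([6, 6])
```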

Attention Weight

[Figure: Computing attention weights]

Credit: Build a Large Language Model (From Scratch)

Now, we compute the attention weights by applying the softmax function to the attention scores. Before applying softmax, we scale the scores by dividing them by the square root of the embedding dimension of the keys, d_k (here, d_k = 2):

keys.shape
torch.Size([6, 2])
d_k = keys.shape[-1]
attention_weights_2 = torch.softmax(attention_score_2 / d_k**0.5, dim=-1)
attention_weights_2
tensor([0.1723, 0.2681, 0.2620, 0.0879, 0.0898, 0.1200])

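Scaling by the square root of the key embedding dimension is why this form of self-attention is called scaled dot-product attention. Without it, dot products tend to grow in magnitude as d_k increases. This pushes the softmax toward a peaked, nearly one-hot distribution where gradients become very small.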
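A quick numerical sketch of this effect follows. It assumes queries and keys with independent standard-normal entries and a hypothetical, deliberately large key dimension of 512. Under these assumptions, the variance of a raw dot product grows roughly like d_k, while dividing by sqrt(d_k) brings it back to about 1. The softmax over unscaled scores is therefore more peaked.

```python
import torch

torch.manual_seed(0)
d_k = 512         # hypothetical large key dimension
num_pairs = 2000  # number of random query/key pairs to sample

# Queries and keys with independent standard-normal entries
q = torch.randn(num_pairs, d_k)
k = torch.randn(num_pairs, d_k)

dots = (q * k).sum(dim=-1)  # one dot product per pair
scaled = dots / d_k ** 0.5

# Variance of raw dot products is on the order of d_k; scaled variance is near 1
print(dots.var().item(), scaled.var().item())

# Softmax over the same 6 scores, unscaled vs. scaled
peaked = torch.softmax(dots[:6], dim=-1)
flatter = torch.softmax(scaled[:6], dim=-1)
print(peaked.max().item(), flatter.max().item())
```

Dividing the scores by a constant greater than 1 acts like raising the softmax temperature, so the largest probability can only shrink or stay the same.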

Context Vector

Now it’s time to compute the context vector. In the simplified version without trainable weights, the context vector was a weighted sum over the input vectors. Here, we compute it as a weighted sum over the value vectors, using the attention weights as the coefficients.

[Figure: Computing the context vector]

Credit: Build a Large Language Model (From Scratch)

context_vec_2 = attention_weights_2 @ values
context_vec_2
tensor([1.4201, 0.8892])
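The product attention_weights_2 @ values is just a compact way of writing the weighted sum. The sketch below uses hypothetical weights and value vectors (not the ones above). It computes the sum explicitly with a loop and checks that it matches the vector-matrix product.

```python
import torch

# Hypothetical attention weights for one query (they sum to 1)
attention_weights = torch.tensor([0.1, 0.2, 0.3, 0.4])
# Hypothetical value vectors, one row per token (d_out = 2)
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0],
                       [2.0, 0.5]])

# Explicit weighted sum: sum_i alpha_i * v_i
context_loop = torch.zeros(values.shape[1])
for alpha, v in zip(attention_weights, values):
    context_loop += alpha * v

# Same result as a single vector-matrix product
context_matmul = attention_weights @ values
print(context_loop, context_matmul)  # both approximately [1.2, 0.7]
```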

Implementing a Compact SelfAttention Class

import torch.nn as nn

Next, we will define a Python class called SelfAttention_v1. This class implements the self-attention mechanism using PyTorch’s nn.Parameter to store the weight matrices (query, key, and value). The forward method shows how these weights are used to compute attention scores, weights, and context vectors.

class SelfAttention_v1(nn.Module):

  def __init__(self, d_in, d_out):
    super().__init__()
    self.W_query = nn.Parameter(torch.randn(d_in, d_out))
    self.W_key = nn.Parameter(torch.randn(d_in, d_out))
    self.W_value = nn.Parameter(torch.randn(d_in, d_out))

  def forward(self, x):
    keys = x @ self.W_key
    queries = x @ self.W_query
    values = x @ self.W_value

    attention_scores = queries @ keys.T
    attention_weights = torch.softmax(
        attention_scores / keys.shape[-1] ** 0.5, dim = -1
    )

    context_vector = attention_weights @ values
    return context_vector
torch.manual_seed(123)
self_attention_v1 = SelfAttention_v1(d_in, d_out)
print(self_attention_v1(inputs))
tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)

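Now, we define a new class, SelfAttention_v2, which is similar to SelfAttention_v1 but uses nn.Linear layers for the query, key, and value transformations. nn.Linear performs the matrix multiplication for us and uses PyTorch’s default initialization scheme for linear layers instead of torch.randn. It also lets us optionally add a bias term via qkv_bias.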

class SelfAttention_v2(nn.Module):

  def __init__(self, d_in, d_out, qkv_bias = False):
    super().__init__()
    self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

  def forward(self, x):
    keys    = self.W_key(x)
    queries = self.W_query(x)
    values  = self.W_value(x)

    attention_scores = queries @ keys.T
    attention_weights = torch.softmax(
        attention_scores / keys.shape[-1] ** 0.5, dim = -1
    )

    context_vector = attention_weights @ values
    return context_vector
torch.manual_seed(768)
self_attention_v2 = SelfAttention_v2(d_in, d_out)
print(self_attention_v2(inputs))
tensor([[-0.0256, -0.0702],
        [-0.0175, -0.0742],
        [-0.0175, -0.0744],
        [-0.0177, -0.0735],
        [-0.0187, -0.0765],
        [-0.0175, -0.0721]], grad_fn=<MmBackward0>)

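Notice that SelfAttention_v1 and SelfAttention_v2 produce different outputs because their weight matrices are initialized differently. v1 samples them with torch.randn, whereas nn.Linear uses its own initialization scheme, and the two examples also use different random seeds.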
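One way to check that the two classes compute the same function is to copy v2’s weights into v1. The sketch below repeats both class definitions so it runs on its own, and uses a hypothetical random 6-token input. One detail matters: nn.Linear stores its weight with shape (d_out, d_in) and computes x @ weight.T. The (d_in, d_out) matrices in v1 therefore have to be the transposes of v2’s weights.

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        keys, queries, values = x @ self.W_key, x @ self.W_query, x @ self.W_value
        weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return weights @ values

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return weights @ values

torch.manual_seed(0)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)  # hypothetical 6-token input

sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v1 = SelfAttention_v1(d_in, d_out)

# Copy v2's (d_out, d_in) Linear weights into v1 as (d_in, d_out) matrices
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

print(torch.allclose(sa_v1(x), sa_v2(x)))  # True
```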

Hiding future words with causal attention

Applying a causal attention mask

To implement causal self-attention, let’s start from the attention scores and weights computed with the SelfAttention_v2 object from the previous section:

# Reuse the query and key weight matrices of the
# SelfAttention_v2 object from the previous section for convenience

queries = self_attention_v2.W_query(inputs)
keys = self_attention_v2.W_key(inputs)

attention_scores = queries @ keys.T
attention_weights = torch.softmax(
    attention_scores / keys.shape[-1] ** 0.5, dim = -1
    )
print(attention_weights)
tensor([[0.1546, 0.1686, 0.1686, 0.1703, 0.1673, 0.1706],
        [0.1741, 0.1653, 0.1653, 0.1649, 0.1653, 0.1651],
        [0.1742, 0.1652, 0.1651, 0.1651, 0.1648, 0.1656],
        [0.1731, 0.1658, 0.1659, 0.1644, 0.1672, 0.1637],
        [0.1737, 0.1636, 0.1631, 0.1691, 0.1561, 0.1744],
        [0.1722, 0.1667, 0.1671, 0.1624, 0.1722, 0.1593]],
       grad_fn=<SoftmaxBackward0>)

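Next, we use torch.tril to build a mask that keeps the diagonal and everything below it while zeroing out the entries above the diagonal. We then multiply it element-wise with the attention weights: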

context_length = attention_scores.shape[0]
mask_simple = torch.tril(
    torch.ones(context_length, context_length)
)
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
masked_simple = attention_weights * mask_simple
print(masked_simple)
tensor([[0.1546, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1741, 0.1653, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1742, 0.1652, 0.1651, 0.0000, 0.0000, 0.0000],
        [0.1731, 0.1658, 0.1659, 0.1644, 0.0000, 0.0000],
        [0.1737, 0.1636, 0.1631, 0.1691, 0.1561, 0.0000],
        [0.1722, 0.1667, 0.1671, 0.1624, 0.1722, 0.1593]],
       grad_fn=<MulBackward0>)

After masking, the rows no longer sum to 1. To fix this, we renormalize the masked attention weights by dividing each row by its sum:

row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5130, 0.4870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3453, 0.3274, 0.3273, 0.0000, 0.0000, 0.0000],
        [0.2586, 0.2478, 0.2479, 0.2457, 0.0000, 0.0000],
        [0.2104, 0.1982, 0.1975, 0.2048, 0.1890, 0.0000],
        [0.1722, 0.1667, 0.1671, 0.1624, 0.1722, 0.1593]],
       grad_fn=<DivBackward0>)

Instead of zeroing out the attention weights above the diagonal and then renormalizing, we can apply a causal mask before the softmax by setting the unnormalized attention scores above the diagonal to negative infinity. Because exp(-inf) = 0, the softmax then assigns zero probability to those masked positions.

mask = torch.triu(
    torch.ones(context_length, context_length), diagonal=1)
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
tensor([[-0.1101,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0592, -0.0141,    -inf,    -inf,    -inf,    -inf],
        [ 0.0569, -0.0181, -0.0186,    -inf,    -inf,    -inf],
        [ 0.0615,  0.0007,  0.0014, -0.0112,    -inf,    -inf],
        [-0.0016, -0.0861, -0.0909, -0.0399, -0.1530,    -inf],
        [ 0.0848,  0.0392,  0.0422,  0.0024,  0.0853, -0.0248]],
       grad_fn=<MaskedFillBackward0>)
attention_weights = torch.softmax(
    masked / keys.shape[-1]**0.5, dim=-1
)
print(attention_weights)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5130, 0.4870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3453, 0.3274, 0.3273, 0.0000, 0.0000, 0.0000],
        [0.2586, 0.2478, 0.2479, 0.2457, 0.0000, 0.0000],
        [0.2104, 0.1982, 0.1975, 0.2048, 0.1890, 0.0000],
        [0.1722, 0.1667, 0.1671, 0.1624, 0.1722, 0.1593]],
       grad_fn=<SoftmaxBackward0>)
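The two approaches give identical weights, as the outputs above show. The reason is that the softmax normalizing constant cancels when a row is renormalized. The sketch below checks this on a hypothetical random 5×5 score matrix with an assumed d_k of 4.

```python
import torch

torch.manual_seed(0)
num_tokens, d_k = 5, 4
scores = torch.randn(num_tokens, num_tokens)  # hypothetical unnormalized attention scores

# Approach 1: softmax, zero out the future, renormalize each row
weights = torch.softmax(scores / d_k ** 0.5, dim=-1)
lower = torch.tril(torch.ones(num_tokens, num_tokens))
masked = weights * lower
approach_1 = masked / masked.sum(dim=-1, keepdim=True)

# Approach 2: set future scores to -inf, then apply softmax once
future = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
approach_2 = torch.softmax(scores.masked_fill(future, -torch.inf) / d_k ** 0.5, dim=-1)

print(torch.allclose(approach_1, approach_2))  # True
print(torch.exp(torch.tensor(-torch.inf)))     # exp(-inf) is exactly zero
```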

Masking additional weights with dropout

Next, we create a Dropout layer with a dropout rate of 50% and apply it to a 6×6 matrix of ones. During training, dropout zeroes each element independently with probability 0.5, so roughly (not exactly) half of the values become zero. The remaining values are scaled by 1 / (1 - 0.5) = 2, which keeps the expected value of each element unchanged.

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # dropout rate = 50%
example = torch.ones(6, 6)      # create a matrix of ones

print(dropout(example))
tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
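The following sketch uses a larger, hypothetical 1000×1000 matrix of ones to show two properties of nn.Dropout. In training mode (the default), surviving values are scaled so the mean stays close to 1. In evaluation mode, after calling .eval(), dropout leaves the input unchanged.

```python
import torch

torch.manual_seed(0)
p = 0.5
dropout = torch.nn.Dropout(p)
x = torch.ones(1000, 1000)

# Training mode (default): elements are zeroed with probability p,
# survivors are scaled by 1 / (1 - p) = 2
out_train = dropout(x)
print(out_train.unique())       # tensor([0., 2.])
print(out_train.mean().item())  # close to 1.0: the expected value is preserved

# Evaluation mode: dropout does nothing
dropout.eval()
out_eval = dropout(x)
print(torch.equal(out_eval, x))  # True
```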

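Then we apply the same Dropout layer to attention_weights. Randomly dropping attention weights during training helps reduce overfitting by preventing the model from relying too heavily on specific attention patterns. Note that after dropout the rows no longer sum to 1, and that dropout is only active in training mode.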

torch.manual_seed(123)
print(dropout(attention_weights))
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6905, 0.6549, 0.6546, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4955, 0.4958, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3964, 0.0000, 0.4096, 0.0000, 0.0000],
        [0.0000, 0.3334, 0.3341, 0.3249, 0.3445, 0.0000]],
       grad_fn=<MulBackward0>)

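Implementing a compact causal self-attention class

To make the class handle batches of inputs, we first build a small batch by stacking our six-token input twice: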

# 2 inputs with 6 tokens each, and each token has embedding dimension 3
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
torch.Size([2, 6, 3])
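With a batch dimension, the keys.T used in the 2D classes no longer does what we want. On a 3D tensor it would reverse all dimensions, and PyTorch deprecates .T on tensors that are not 2D. The sketch below uses hypothetical random queries and keys shaped like our batch. It shows that keys.transpose(1, 2) swaps only the token and feature dimensions, so the batched product equals one 2D product per example.

```python
import torch

torch.manual_seed(0)
b, num_tokens, d = 2, 6, 2
queries = torch.rand(b, num_tokens, d)
keys = torch.rand(b, num_tokens, d)

# transpose(1, 2) keeps the batch dimension in place: (b, d, num_tokens)
scores = queries @ keys.transpose(1, 2)  # (b, num_tokens, num_tokens)
print(scores.shape)                      # torch.Size([2, 6, 6])

# Batched matmul is equivalent to one 2D product per batch element
per_example = torch.stack([queries[i] @ keys[i].T for i in range(b)])
print(torch.allclose(scores, per_example))  # True
```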

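Next up, we create a small custom attention layer called CasualAttention. It extends SelfAttention_v2 in three ways. It supports batched inputs. It applies a causal mask, registered as a buffer so that it moves to the same device as the model without being trained. It also applies dropout to the attention weights. Thanks to the mask, each token can attend only to itself and earlier tokens, while future tokens are blocked. That’s why we call it causal: the model can only “see” the past.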

class CasualAttention(nn.Module):

    def __init__(
        self,
        d_in,
        d_out,
        context_length,
        dropout,
        qkv_bias=False
    ):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):

        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.transpose(1, 2)
        # masked_fill_ (trailing underscore) applies the mask in place
        attention_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )

        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5,
            dim = -1
        )
        attention_weights = self.dropout(attention_weights)

        context_vector = attention_weights @ values
        return context_vector
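A detail worth checking in forward: PyTorch’s masked_fill returns a new tensor, while masked_fill_ (with a trailing underscore) modifies the tensor in place. The masked scores are only used if the call is in place or its result is reassigned. The sketch below demonstrates this on a hypothetical 3×3 score matrix.

```python
import torch

scores = torch.zeros(3, 3)
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()  # 3 entries above the diagonal

# Out-of-place: returns a new tensor and leaves `scores` unchanged
result = scores.masked_fill(mask, -torch.inf)
scores_untouched = not torch.isinf(scores).any().item()
print(scores_untouched)                  # True
print(torch.isinf(result).sum().item())  # 3

# In-place (trailing underscore): modifies `scores` directly
scores.masked_fill_(mask, -torch.inf)
print(torch.isinf(scores).sum().item())  # 3
```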

Here we test the attention layer.

  • Create an instance of CasualAttention.
  • Pass our batch through it to get context vectors.

This confirms that the layer runs and returns an output of the expected shape: (batch size, number of tokens, d_out) = (2, 6, 2).

torch.manual_seed(123)

context_length = batch.shape[1]
ca = CasualAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)

print(context_vecs)
print(f"Context vector shape: {context_vecs.shape}")
tensor([[[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]],

        [[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
Context vector shape: torch.Size([2, 6, 2])
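A shape check confirms only the output size, not that the mask works. A stricter causality test changes the last token and confirms that the context vectors of the earlier tokens do not change. The sketch below runs this test on a minimal stand-alone function, causal_attention (a hypothetical helper, not part of the code above), with random projection matrices. The same check could be applied to a CasualAttention instance.

```python
import torch

torch.manual_seed(0)
d_in, d_out, num_tokens = 3, 2, 6
# Hypothetical stand-in projections (plain tensors rather than nn.Linear layers)
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

def causal_attention(x):
    """Minimal causal scaled dot-product attention for x of shape (b, num_tokens, d_in)."""
    queries, keys, values = x @ W_query, x @ W_key, x @ W_value
    scores = queries @ keys.transpose(1, 2)
    n = x.shape[1]
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = scores.masked_fill(future, -torch.inf)
    weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
    return weights @ values

x = torch.rand(1, num_tokens, d_in)
x_changed = x.clone()
x_changed[:, -1] = torch.rand(d_in)  # change only the last token

out, out_changed = causal_attention(x), causal_attention(x_changed)

# Earlier tokens cannot see the last token, so their outputs stay the same
print(torch.allclose(out[:, :-1], out_changed[:, :-1]))  # True
# The last token's own output does change
print(torch.allclose(out[:, -1], out_changed[:, -1]))    # False
```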

Extending single-head attention to multi-head attention

In this code, we build a simple version of multi-head attention by stacking several CasualAttention instances, one per head, and concatenating their outputs along the last dimension. Instead of a single attention operation, the heads run in parallel, each with its own query, key, and value projections. This lets each head learn to attend to different aspects of the input.

class MultiHeadAttentionWrapper(nn.Module):

    def __init__(
        self,
        d_in,
        d_out,
        context_length,
        dropout,
        num_heads,
        qkv_bias = False
    ):
        super().__init__()
        self.heads = nn.ModuleList(
            [CasualAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

Here we test the multi-head attention wrapper. We set a random seed, define the input sizes, create the wrapper with two heads, and pass the batch through it. Each of the two heads returns 2-dimensional context vectors, and these are concatenated, so the final dimension is num_heads × d_out = 4.

torch.manual_seed(123)

context_length = batch.shape[1]   # number of tokens
d_in, d_out = 3, 2

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print(f"context vectors shape: {context_vecs.shape}")
tensor([[[-0.5337, -0.1051,  0.5085,  0.3508],
         [-0.5323, -0.1080,  0.5084,  0.3508],
         [-0.5323, -0.1079,  0.5084,  0.3506],
         [-0.5297, -0.1076,  0.5074,  0.3471],
         [-0.5311, -0.1066,  0.5076,  0.3446],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.5337, -0.1051,  0.5085,  0.3508],
         [-0.5323, -0.1080,  0.5084,  0.3508],
         [-0.5323, -0.1079,  0.5084,  0.3506],
         [-0.5297, -0.1076,  0.5074,  0.3471],
         [-0.5311, -0.1066,  0.5076,  0.3446],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context vectors shape: torch.Size([2, 6, 4])

Implementing multi-head attention with weight splits

This implementation provides a stand-alone MultiHeadAttention module that achieves the same functionality as wrapping multiple single-head causal attention blocks, using one batched computation instead of a loop over heads. Instead of running separate attention modules and concatenating their outputs, we create a single set of W_query, W_key, and W_value projections with output size d_out and apply them to the input once. We then split the results into num_heads heads of size head_dim = d_out // num_heads, which are processed in parallel. A final linear layer, out_proj, combines the head outputs.
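A single projection followed by view and transpose behaves like a separate projection per head. The small check below illustrates this with hypothetical sizes (d_out = 4, split into two heads of size 2) and a random input. Head h uses rows h*head_dim to (h+1)*head_dim of the Linear weight, which amounts to its own (head_dim × d_in) projection matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, num_tokens, d_in, d_out, num_heads = 2, 6, 3, 4, 2
head_dim = d_out // num_heads

x = torch.rand(b, num_tokens, d_in)
W_query = nn.Linear(d_in, d_out, bias=False)

# One projection for all heads, then split the last dimension into heads
queries = W_query(x)  # (b, num_tokens, d_out)
queries = queries.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
# queries: (b, num_heads, num_tokens, head_dim)

# Compare each head with a projection using only that head's slice of the weight
matches = []
for h in range(num_heads):
    W_h = W_query.weight[h * head_dim:(h + 1) * head_dim]  # (head_dim, d_in)
    matches.append(torch.allclose(queries[:, h], x @ W_h.T))
print(matches)  # [True, True]
```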

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_in,
        d_out,
        context_length,
        dropout,
        num_heads,
        qkv_bias = False
    ):

        super().__init__()
        assert (d_out % num_heads) == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Linear layers for Q, K, V projections
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # linear layer to combine head outputs
        self.out_proj = nn.Linear(d_out, d_out)

        self.dropout = nn.Dropout(dropout)
        # Causal mask: allows attending only to past tokens
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )


    def forward(self, x):

        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)           # (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)     
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)       
        values = values.transpose(1, 2)
        queries = queries.transpose(1, 2)

        # (b, num_heads, num_tokens, num_tokens); scaling by sqrt(head_dim) is applied once, in the softmax below
        attention_scores = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask_bool, -torch.inf)

        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1]**0.5,
            dim=-1
        )      # (b, num_heads, num_tokens, num_tokens)

        attention_weights = self.dropout(attention_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attention_weights @ values).transpose(1, 2)  
        # (b, num_tokens, d_out)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)   
        context_vec = self.out_proj(context_vec)    # (b, num_tokens, d_out)

        return context_vec
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2

mha = MultiHeadAttention(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print(f"context vectors shape: {context_vecs.shape}")
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context vectors shape: torch.Size([2, 6, 2])

If you prefer a more compact and optimized implementation, you can also use PyTorch’s built-in torch.nn.MultiheadAttention class, which supports causal masking through its attn_mask argument.
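Below is a minimal sketch of causal self-attention with nn.MultiheadAttention. It assumes a PyTorch version that supports the batch_first argument. The sizes and the random input are hypothetical. Unlike the classes above, the query input and the output both have size embed_dim, which must be divisible by num_heads. For a boolean attn_mask, True marks the positions a query is not allowed to attend to.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, num_tokens, embed_dim, num_heads = 2, 6, 4, 2  # embed_dim divisible by num_heads

x = torch.rand(b, num_tokens, embed_dim)

mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=False, batch_first=True)

# Boolean mask: True above the diagonal blocks attention to future tokens
causal_mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# Self-attention: query, key, and value are all x
context_vecs, attn_weights = mha(x, x, x, attn_mask=causal_mask)

print(context_vecs.shape)  # torch.Size([2, 6, 4])
print(attn_weights.shape)  # torch.Size([2, 6, 6]), averaged over heads by default
```

Like the MultiHeadAttention class above, nn.MultiheadAttention includes an output projection after combining the heads.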