Implementing Self-Attention from Scratch in PyTorch

Mohd Faraaz
Aug 31, 2024


In this article, we will walk step by step through the self-attention mechanism, which sits at the heart of the transformer architecture powering today's LLMs. Working through it will help you better understand how LLMs generate text.

Introduction to Self-Attention

Self-attention, often referred to as scaled dot-product attention, is a mechanism that allows a model to weigh the importance of different parts of the input data when processing each element. It’s a crucial component in transformer models, enabling them to capture dependencies regardless of their distance in the input sequence. This ability makes self-attention particularly powerful for tasks like machine translation, text generation, and more.

How Does Self-Attention Work?

The self-attention mechanism operates in the following steps:

  1. Create Query, Key, and Value Vectors: Each input token is transformed into three vectors: a query vector, a key vector, and a value vector. These vectors are derived from the original embeddings of the tokens.
  2. Calculate Attention Scores: The attention scores are computed by taking the dot product of the query vector of a particular token with the key vectors of all tokens in the sequence. These scores measure the relevance of other tokens to the current token.
  3. Apply Softmax: The attention scores are normalized using the softmax function to get the attention weights. This ensures that the weights sum to 1, highlighting the most relevant tokens.
  4. Generate Context Vectors: Finally, each token’s value vector is weighted by the corresponding attention weight, and these weighted vectors are summed to produce a context vector for each token.
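
Taken together, these four steps implement the scaled dot-product attention formula:

Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V

where Q, K, and V stack the query, key, and value vectors of all tokens, and d_k is the key dimension.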

Implementing Self-Attention from Scratch

We will implement a simple self-attention mechanism using PyTorch. For this walkthrough, we’ll use the sentence "The quick brown fox jumps over a lazy dog" and follow through each step of the self-attention process.

1. Data Preparation

First, we convert the sentence into a list of integers, where each unique word is assigned a unique index.

sentence = 'The quick brown fox jumps over a lazy dog'
dc = {s: i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

Output:

{'The': 0, 'a': 1, 'brown': 2, 'dog': 3, 'fox': 4, 'jumps': 5, 'lazy': 6, 'over': 7, 'quick': 8}

Now, we map the sentence to a tensor of these integer indices.

import torch

r = [dc[i] for i in sentence.replace(',', '').split()]
sentence_int = torch.tensor(r)
print(sentence_int)

Output:

tensor([0, 8, 2, 4, 5, 7, 1, 6, 3])

2. Embedding the Sentence

We embed the integer indices into vectors using an embedding layer.

import torch
import torch.nn as nn
vocab_size = 50000  # Assume a large vocabulary size
torch.manual_seed(123)
embed = nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)

Output:

tensor([[-0.4512, -1.0873,  0.0702],
        [-0.1546, -0.5932,  1.3215],
        [-1.4360,  1.1305, -1.1885],
        [ 0.8595,  0.1979, -1.2390],
        [-0.3320,  1.8563,  0.0967],
        [-1.7215, -1.0282,  0.0643],
        [-0.6843,  0.8705, -0.7858],
        [-1.3713,  0.4267, -0.5338],
        [-0.6796,  0.2673,  1.2914]])
torch.Size([9, 3])

The sentence is now represented as a 9x3 matrix, with each word embedded into a 3-dimensional vector.

3. Self-Attention Mechanism

We transform the embedded sentence to apply self-attention, starting by creating query, key, and value matrices.

torch.manual_seed(123)
d = embedded_sentence.shape[1] # Dimension of embeddings
d_q, d_k, d_v = 2, 2, 4 # Dimensions for query, key, and value matrices
W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))
query = embedded_sentence @ W_query
key = embedded_sentence @ W_key
value = embedded_sentence @ W_value
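
Printing the shapes (a quick check, not shown in the original output) confirms that each of the nine tokens now has a 2-dimensional query and key and a 4-dimensional value:

print(query.shape)  # torch.Size([9, 2])
print(key.shape)    # torch.Size([9, 2])
print(value.shape)  # torch.Size([9, 4])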

4. Calculating Attention Scores

The attention scores are calculated as the dot product of query and key matrices, followed by scaling.

import math
import torch.nn.functional as F

attention_scores = query @ key.T                         # dot products between every query and every key
attention_scores = attention_scores / math.sqrt(d_k)     # scale by sqrt(d_k) to keep the softmax from saturating
attention_weights = F.softmax(attention_scores, dim=-1)  # normalize each row into attention weights
print(attention_weights)

Output:

tensor([[1.0000, 0.0000, ...],
        [0.0000, 1.0000, ...],
        ...
        [0.0000, 0.0000, ...]])
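
Because softmax normalizes each row, the attention weights for every token sum to 1. A quick check (not part of the original output) makes this explicit:

print(attention_weights.shape)        # torch.Size([9, 9]) - one row of weights per token
print(attention_weights.sum(dim=-1))  # each entry should be (approximately) 1.0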

5. Generating Context Vectors

Finally, we use the attention weights to compute the context vectors.

context_vector = attention_weights @ value
print(context_vector)

Output:

tensor([[ 0.0630, -0.3225, -0.1370,  0.3293],
        [-0.2525, -0.0768, -0.3272, -0.3774],
        ...
        [-0.2211,  0.2421, -0.4866,  0.1577]])
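
As an optional sanity check (this step is not in the original walkthrough and assumes PyTorch 2.0 or later), the same result should be reproducible with PyTorch's built-in F.scaled_dot_product_attention, which computes softmax(QKᵀ/√d) · V with the same scaling here, since d_q equals d_k:

# Optional comparison against the built-in implementation (PyTorch >= 2.0).
# A leading batch dimension is added because the built-in expects batched inputs.
builtin = F.scaled_dot_product_attention(
    query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0)
).squeeze(0)
print(torch.allclose(builtin, context_vector, atol=1e-6))  # expected: True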

6. Encapsulating in a Module

Let’s encapsulate the entire self-attention mechanism in a PyTorch module.

class SelfAttention(nn.Module):
    def __init__(self, d, d_q, d_k, d_v):
        super().__init__()
        self.d = d
        self.d_q = d_q
        self.d_k = d_k
        self.d_v = d_v
        # Learnable projection matrices for queries, keys, and values
        self.W_query = nn.Parameter(torch.rand(d, d_q))
        self.W_key = nn.Parameter(torch.rand(d, d_k))
        self.W_value = nn.Parameter(torch.rand(d, d_v))

    def forward(self, x):
        # x: (seq_len, d) embedded input
        Q = x @ self.W_query
        K = x @ self.W_key
        V = x @ self.W_value
        attention_scores = Q @ K.T / math.sqrt(self.d_k)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context_vector = attention_weights @ V
        return context_vector
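
One limitation worth noting: Q @ K.T only works for a single unbatched input of shape (seq_len, d). A hypothetical batched variant (a sketch, not part of the original article) only needs to transpose the last two dimensions instead:

class BatchedSelfAttention(SelfAttention):
    # Sketch of a batched variant: identical weights, but the matmuls
    # broadcast over a leading batch dimension, i.e. x has shape (batch, seq_len, d).
    def forward(self, x):
        Q = x @ self.W_query
        K = x @ self.W_key
        V = x @ self.W_value
        attention_scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        attention_weights = F.softmax(attention_scores, dim=-1)
        return attention_weights @ V

Passing embedded_sentence.unsqueeze(0) (shape (1, 9, 3)) through this variant would yield a context tensor of shape (1, 9, 4).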

7. Using the Self-Attention Module

We instantiate our SelfAttention module and pass the embedded sentence through it.

sa = SelfAttention(d=3, d_q=2, d_k=2, d_v=4)
cv = sa(embedded_sentence)
print(cv.shape)
print(cv)

Output:

torch.Size([9, 4])
tensor([[-0.2494, -0.3343, -0.0018, -0.2718],
        [ 0.0908, -0.1548,  0.1734,  0.0008],
        ...
        [ 0.1173, -0.1928, -0.1367, -0.0935]])

Conclusion

In this article, we’ve built a self-attention mechanism from scratch using PyTorch. We followed a structured approach:

  1. Data Preparation: Converted the sentence into numerical indices.
  2. Embedding: Mapped words to continuous vectors.
  3. Self-Attention Mechanism: Created query, key, and value matrices, and calculated attention scores and context vectors.
  4. Encapsulation: Encapsulated the process in a PyTorch module.

This simple implementation highlights the core concepts of self-attention, which are integral to transformer models used in advanced NLP tasks. Understanding this mechanism is foundational for diving deeper into modern deep learning architectures.

Feel free to experiment with different dimensions and input data to see how the self-attention mechanism adapts and scales.
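
For instance, only the value dimension controls the width of the resulting context vectors; a hypothetical wider configuration would look like this:

sa_wide = SelfAttention(d=3, d_q=2, d_k=2, d_v=8)  # wider value/context dimension
print(sa_wide(embedded_sentence).shape)            # torch.Size([9, 8])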
