
EDIT 2022-06-24: this code is now available (with helper functions) in the R package attention, which is on CRAN. You can install it simply using:

install.packages('attention')

See also my blog post "attention on CRAN". Development takes place on GitHub.


This post describes how to implement the attention mechanism, which forms the basis of transformers, in the R language.

The code is translated from the Python original by Stefania Cristina (University of Malta) in her post "The Attention Mechanism from Scratch".

We begin by generating encoder representations of four different words.

# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)

Next, we stack the word embeddings into a single array (in this case a matrix).

# stacking the word embeddings into a single array
words = rbind(word_1,
              word_2,
              word_3,
              word_4)

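Next, we generate the three weight matrices W_Q, W_K and W_V. Each is filled with random integers from {0, 1, 2}: runif() draws uniformly between 0 and 3, and floor() rounds each draw down.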

# generating the weight matrices
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
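As a quick sanity check on the values this produces, the standalone snippet below shows that floor() applied to uniform draws between 0 and 3 only ever returns 0, 1 or 2, never 3. The sample() call at the end is an alternative we suggest for drawing the same integers directly. It is not what the original code uses, and W_example is just an illustrative name.

```r
# floor() of uniform draws between 0 and 3 yields only the integers 0, 1 and 2
set.seed(0)
draws = floor(runif(1000, min = 0, max = 3))
print(sort(unique(draws)))  # [1] 0 1 2

# alternative (not used in this post): sample the integers 0..2 directly
W_example = matrix(sample(0:2, 9, replace = TRUE), nrow = 3, ncol = 3)
print(W_example)
```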

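R's random number generator does not produce the same sequence as NumPy's, so the values above differ from those in the original post. To keep the numbers identical to the original Python code, you can overwrite the randomly generated values with the ones Python generated.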

# redefine matrices to match random numbers generated by Python in the original code
W_Q = matrix(c(2,0,2,
               2,0,0,
               2,1,2),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_K = matrix(c(2,2,2,
               0,2,1,
               0,1,1),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_V = matrix(c(1,1,0,
               0,1,1,
               0,0,0),
             nrow=3,
             ncol=3,
             byrow = TRUE)

Next, we generate the queries (Q), keys (K) and values (V) by multiplying the stacked word embeddings with the corresponding weight matrices. The %*% operator performs matrix multiplication; you can view its R help page using help('%*%').

# generating the queries, keys and values
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
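Each row of Q is simply the corresponding word embedding multiplied by W_Q. Stacking the words first therefore lets a single matrix product compute all four projections at once. Here is a minimal check, with words and W_Q repeated so the snippet runs on its own:

```r
# words and W_Q as defined above
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
W_Q = matrix(c(2,0,2,
               2,0,0,
               2,1,2), nrow = 3, ncol = 3, byrow = TRUE)

Q = words %*% W_Q
print(Q)
#      [,1] [,2] [,3]
# [1,]    2    0    2
# [2,]    2    0    0
# [3,]    4    0    2
# [4,]    2    1    2

# projecting the third word on its own gives the third row of Q
q_3 = words[3, , drop = FALSE] %*% W_Q
stopifnot(all(q_3 == Q[3, ]))
```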

Following this, we score every query vector against all key vectors.

# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
     [,1] [,2] [,3] [,4]
[1,]    8    2   10    2
[2,]    4    0    4    0
[3,]   12    2   14    2
[4,]   10    4   14    3
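Entry [i, j] of scores is the dot product of query i with key j. Base R's tcrossprod(x, y) computes x %*% t(y) in one call. The sketch below uses it only to confirm the equivalence; the original code does not use it.

```r
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
W_Q = matrix(c(2,0,2, 2,0,0, 2,1,2), nrow = 3, byrow = TRUE)
W_K = matrix(c(2,2,2, 0,2,1, 0,1,1), nrow = 3, byrow = TRUE)
Q = words %*% W_Q
K = words %*% W_K

scores = Q %*% t(K)

# score of query 4 against key 3 is a plain dot product
print(sum(Q[4, ] * K[3, ]))  # [1] 14
stopifnot(scores[4, 3] == sum(Q[4, ] * K[3, ]))

# tcrossprod(Q, K) is equivalent to Q %*% t(K)
stopifnot(all(tcrossprod(Q, K) == scores))
```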

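We now need the maximum value of each row of the scores matrix. In the softmax step this maximum is subtracted from every score in its row before exponentiating, which keeps exp() from overflowing without changing the result. We find the maxima with apply() (see help('apply')), applying the max() function over MARGIN = 1 (i.e. the rows). Don't worry too much about how this works; the key takeaway is that we get the maximum of each row. Because apply() returns a plain vector, we wrap it in as.matrix() to obtain a 4 × 1 matrix maxs whose i-th row holds the maximum of row i of scores.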

# row-wise maxima of scores (MARGIN = 1 selects rows; argument names are case-sensitive)
maxs = as.matrix(apply(scores, MARGIN = 1, FUN = max))
print(maxs)
     [,1]
[1,]   10
[2,]    4
[3,]   14
[4,]   14

As you can see, the value for each row in maxs is the maximum value of the corresponding row in scores.
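Subtracting each row's maximum is the usual trick for a numerically stable softmax. Softmax is unchanged when the same constant is subtracted from every element. After the shift, the largest exponent is exp(0) = 1, so nothing overflows. The illustration below uses deliberately large, made-up scores that are not taken from this post.

```r
x = c(1000, 1001, 1002)   # made-up scores, large enough to overflow exp()

naive = exp(x) / sum(exp(x))                      # exp(1000) is Inf, so Inf/Inf gives NaN
stable = exp(x - max(x)) / sum(exp(x - max(x)))   # largest term becomes exp(0) = 1

print(naive)   # [1] NaN NaN NaN
print(stable)  # identical to the softmax of c(0, 1, 2)
```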

The weights matrix is filled in row by row with a for loop (see help('for')). The loop assigns values into existing rows rather than growing the matrix. We therefore first create a zero matrix (i.e. all values set to 0) with the right dimensions, and then overwrite its rows inside the loop.

# initialize weights matrix
weights = matrix(0, nrow=4, ncol=4)

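We now populate the weights matrix using the for loop. Row i of weights is the softmax of row i of scores after dividing by the square root of the key dimension, ncol(K) = 3, as in scaled dot-product attention.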

# computing the weights by a softmax operation
for (i in 1:nrow(scores)) {
  # shift by the row maximum, then scale by the square root of the key dimension
  scaled = (scores[i,] - maxs[i,]) / sqrt(ncol(K))
  weights[i,] = exp(scaled) / sum(exp(scaled))
}
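The loop can also be replaced by whole-matrix operations. In the sketch below, softmax_rows() is a hypothetical helper name of our own, not something from the original post or the attention package. R recycles a vector of length nrow(scores) down the columns. Subtracting apply(scores, 1, max) therefore removes each row's own maximum, and dividing by rowSums() normalises each row.

```r
# hypothetical helper: row-wise scaled softmax without an explicit loop
softmax_rows = function(scores, d_k) {
  # subtract each row's maximum (recycled down the columns), then scale by sqrt(d_k)
  scaled = (scores - apply(scores, MARGIN = 1, FUN = max)) / sqrt(d_k)
  # exponentiate and divide each row by its own sum
  exp(scaled) / rowSums(exp(scaled))
}

scores = rbind(c(8, 2, 10, 2),
               c(4, 0, 4, 0),
               c(12, 2, 14, 2),
               c(10, 4, 14, 3))
weights = softmax_rows(scores, d_k = 3)   # d_k = ncol(K) = 3 in this post
print(rowSums(weights))  # [1] 1 1 1 1
```

The vectorised form avoids pre-allocating weights and indexing row by row. For a 4 × 4 matrix the difference is purely stylistic.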

Finally, we compute the attention as a weighted sum of the value vectors.

# computing the attention by a weighted sum of the value vectors
attention = weights %*% V
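Each row of attention is a weighted sum of the rows of V: row i equals the sum over j of weights[i, j] * V[j, ]. The sketch below uses the V produced by the matrices above and made-up weights for a single query. These weights are for illustration only and are not the ones computed in this post.

```r
# value vectors: words %*% W_V with the matrices defined above
V = rbind(c(1, 1, 0),
          c(0, 1, 1),
          c(1, 2, 1),
          c(0, 0, 0))

w = c(0.25, 0.25, 0.5, 0)   # made-up attention weights for one query (sum to 1)

# the matrix product and an explicit weighted sum of the rows of V agree
via_product = w %*% V
via_sum = colSums(w * V)    # w is recycled down the columns, scaling row j by w[j]
print(via_sum)  # [1] 0.75 1.50 0.75
```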

Now we can view the results using:

print(attention)

This gives:

          [,1]     [,2]      [,3]
[1,] 0.9852202 1.741741 0.7565203
[2,] 0.9096526 1.409653 0.5000000
[3,] 0.9985123 1.758493 0.7599811
[4,] 0.9956039 1.904073 0.9084692

As you can see, these are the same values as those computed in Python in the original post.

The complete code is also available as a Gist on GitHub.
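For reuse, the steps in this post can be wrapped in a single function. This is a sketch of our own. The name scaled_dot_product_attention() is hypothetical and is not the interface of the attention package. It takes the stacked embeddings and the three weight matrices, and it should return the same attention matrix printed earlier.

```r
# hypothetical wrapper around the steps in this post
scaled_dot_product_attention = function(words, W_Q, W_K, W_V) {
  Q = words %*% W_Q
  K = words %*% W_K
  V = words %*% W_V
  scores = Q %*% t(K)
  # numerically stable softmax over each row, scaled by sqrt of the key dimension
  scaled = (scores - apply(scores, MARGIN = 1, FUN = max)) / sqrt(ncol(K))
  weights = exp(scaled) / rowSums(exp(scaled))
  # weighted sum of the value vectors
  weights %*% V
}

words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
W_Q = matrix(c(2,0,2, 2,0,0, 2,1,2), nrow = 3, byrow = TRUE)
W_K = matrix(c(2,2,2, 0,2,1, 0,1,1), nrow = 3, byrow = TRUE)
W_V = matrix(c(1,1,0, 0,1,1, 0,0,0), nrow = 3, byrow = TRUE)

attention = scaled_dot_product_attention(words, W_Q, W_K, W_V)
print(attention)
```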