This vignette describes how to implement the attention mechanism, which forms the basis of transformers, in the R language.
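For orientation, the steps in this vignette follow the general shape of scaled dot-product attention as it is usually written for transformers. The textbook form is given below as a reference point. Here the softmax is taken over each row, and $d_k$ is the number of columns of $K$. In this example $d_k = 3$, because each word has a 3-dimensional representation.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```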
We begin by generating encoder representations of four different words.
# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)
Next, we stack the word embeddings into a single array (in this case a matrix), which we call words.
Let’s see what this looks like.
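One way to perform this stacking is base R's rbind(), which binds the four 1 × 3 row matrices into a single 4 × 3 matrix with one row per word. The following minimal sketch assumes rbind() is used for this step:

```r
# encoder representations of four different words (as above)
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)

# stacking the 1 x 3 row matrices vertically gives a 4 x 3 matrix:
# one row per word, one column per embedding dimension
words = rbind(word_1, word_2, word_3, word_4)
print(words)
#>      [,1] [,2] [,3]
#> [1,]    1    0    0
#> [2,]    0    1    0
#> [3,]    1    1    0
#> [4,]    0    0    1
```

Row 3 is the sum of rows 1 and 2, because word_3 was defined as word_1 + word_2. Every later step is a matrix product, so this linear relationship carries through to the queries, keys, values and scores.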
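Next, we initialize the weight matrices W_Q, W_K and W_V with random integers. The expression floor(runif(9, min=0, max=3)) draws nine values uniformly from [0, 3) and rounds them down, so every entry is 0, 1 or 2.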
# initializing the weight matrices (with random values)
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
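runif() never returns its endpoints, and floor() rounds down, so only three integer values can appear in the weight matrices. A quick check with many draws confirms this. The sample size of 10000 is an arbitrary choice:

```r
# many uniform draws on (0, 3), rounded down as in the weight initialization
set.seed(0)
draws = floor(runif(10000, min = 0, max = 3))

# only three distinct values survive the rounding
sort(unique(draws))
#> [1] 0 1 2
```

sample(0:2, 9, replace = TRUE) draws from the same uniform distribution over {0, 1, 2}. However, it consumes the random number stream differently, so after set.seed(0) it would generally not reproduce the same matrices.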
Next, we generate the Queries (Q), Keys (K), and Values (V). The %*% operator performs matrix multiplication. You can view its R help page using help('%*%') (or see the online manual An Introduction to R).
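The following minimal sketch of this step assumes the standard attention setup, in which each projection multiplies the stacked words matrix by the corresponding weight matrix. The resulting Q, K and V each have one row per word:

```r
# stacked encoder representations (one row per word)
words = rbind(c(1, 0, 0), c(0, 1, 0), c(1, 1, 0), c(0, 0, 1))

# weight matrices initialized as above
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

# projecting every word into query, key and value space (assumed form)
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V

dim(Q)
#> [1] 4 3
```

Matrix multiplication is linear, and word_3 = word_1 + word_2. As a result, the third row of Q (and likewise of K and V) equals the sum of its first two rows.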
Following this, we score the Queries (Q) against the Key (K) vectors, which are transposed for the multiplication using t() (see help('t') for more info).
# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
#> [,1] [,2] [,3] [,4]
#> [1,] 6 4 10 5
#> [2,] 4 6 10 6
#> [3,] 10 10 20 11
#> [4,] 3 1 4 2
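Entry [i, j] of scores is the dot product of query i with key j. The sketch below recomputes the scores under the same assumed projections (words %*% W_Q and words %*% W_K) and checks one entry by hand. It also uses base R's tcrossprod(), which computes x %*% t(y) in a single call:

```r
words = rbind(c(1, 0, 0), c(0, 1, 0), c(1, 1, 0), c(0, 0, 1))
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

# assumed form of the projections
Q = words %*% W_Q
K = words %*% W_K

scores = Q %*% t(K)

# one entry by hand: query of word 1 against key of word 3
scores[1, 3] == sum(Q[1, ] * K[3, ])
#> [1] TRUE

# tcrossprod(Q, K) gives the same product without forming t(K) explicitly
all.equal(scores, tcrossprod(Q, K))
#> [1] TRUE
```

The printed scores show the same linearity as the word representations. Row 3 is the sum of rows 1 and 2 (10 = 6 + 4, 10 = 4 + 6, 20 = 10 + 10, 11 = 5 + 6), and column 3 is likewise the sum of columns 1 and 2.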
We now generate the weights matrix.
Let’s have a look at the weights matrix.
print(weights)
#> [,1] [,2] [,3] [,4]
#> [1,] -0.2986355 -2.6877197 4.479533 -1.4931776
#> [2,] -3.1208558 -0.6241712 4.369198 -0.6241712
#> [3,] -1.7790165 -1.7790165 4.690134 -1.1321014
#> [4,] 1.2167336 -3.6502008 3.650201 -1.2167336
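For comparison, the following minimal sketch shows the textbook scaled dot-product step, starting from the scores printed earlier. It divides the scores by sqrt(d_k), with d_k = 3, and applies a softmax to each row. softmax_rows() and weights_softmax are hypothetical names, not part of the vignette or package. Under this formulation every weight is positive and each row sums to 1, so the result differs from the weights matrix printed above, which contains negative entries.

```r
# scores as printed above
scores = matrix(c( 6,  4, 10,  5,
                   4,  6, 10,  6,
                  10, 10, 20, 11,
                   3,  1,  4,  2), nrow = 4, byrow = TRUE)

# hypothetical helper: numerically stable softmax applied to each row
softmax_rows = function(x) {
  # subtracting each row's maximum leaves the result unchanged but avoids overflow
  e = exp(x - apply(x, 1, max))
  e / rowSums(e)
}

# scale by sqrt(d_k), where d_k = 3 is the number of columns of K
d_k = 3
weights_softmax = softmax_rows(scores / sqrt(d_k))

rowSums(weights_softmax)
#> [1] 1 1 1 1
```

The vector returned by apply(x, 1, max) has one entry per row. R recycles it down the columns, so each element is shifted by its own row's maximum.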
Finally, we compute the attention as a weighted sum of the value vectors (which are combined in the matrix V).
Now we can view the results using:
print(attention)
#> [,1] [,2] [,3]
#> [1,] 7.167252 6.868617 -1.4931776
#> [2,] 4.993369 1.872514 -0.6241712
#> [3,] 6.469151 4.690134 -1.1321014
#> [4,] 7.300402 8.517135 -1.2167336
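The weighted sum can be checked with a compact end-to-end sketch. attention_sketch() is a hypothetical helper that assumes the textbook softmax normalization, so it will not reproduce the printed numbers above. Its weights are non-negative and sum to 1 per row. Each output row is therefore a convex combination of the value vectors and stays within the column-wise range of V.

```r
# hypothetical helper: textbook scaled dot-product attention
attention_sketch = function(Q, K, V) {
  scores = Q %*% t(K) / sqrt(ncol(K))     # scaled query-key similarities
  e = exp(scores - apply(scores, 1, max))  # stable exponentials, row by row
  weights = e / rowSums(e)                 # each row now sums to 1
  weights %*% V                            # weighted sum of the value vectors
}

words = rbind(c(1, 0, 0), c(0, 1, 0), c(1, 1, 0), c(0, 0, 1))
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

# assumed form of the projections
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V

out = attention_sketch(Q, K, V)
dim(out)
#> [1] 4 3
```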