The sbo package provides utilities for building and evaluating next-word prediction functions based on Stupid Back-off N-gram models in R. In this vignette, I illustrate the functions and classes exported by sbo, the typical workflow for building a text predictor from a given training corpus, and the evaluation of next-word predictions on a test corpus. In the last section, I list some features planned for a future version of sbo.
The sbo package pivots around two (S3) object classes:

kgram_freqs: A collection of \(k\)-gram frequency tables, with \(k\) up to a given order \(N\).

sbo_preds: A collection of tables employed to store and retrieve next-word predictions in a compact and efficient way.

The functions get_word_freqs and get_kgram_freqs extract word and \(k\)-gram frequency tables, respectively, from a training corpus, and the function build_sbo_preds constructs a next-word prediction table from a kgram_freqs object. I illustrate the entire process of building a text-prediction function from a training corpus in the next section.
Building a text predictor with sbo

In this and the next section, we will employ the twitter_train and twitter_test example datasets, included in sbo for illustrative purposes:
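The chunk below binds the datasets to the shorter names used throughout this vignette (a minimal setup sketch; we also attach dplyr, on which later chunks rely for %>%, pull(), filter() and summarise()):

library(sbo)   # provides the twitter_train and twitter_test example datasets
library(dplyr) # for %>%, pull(), filter() and summarise(), used below
train <- twitter_train
test <- twitter_test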
These are small samples of \(10^5\) and \(10^4\) entries, respectively, from the “Tweets” Swiftkey dataset, which is fully available online. Each entry consists of a single tweet in English, e.g.:
head(train, 3)
#> [1] "Just realized that Cedar Block is equal parts nutjob conspiracy theorist and pragmatic skeptic. Which side will win? Stay tuned."
#> [2] "Doesn't get any stricter than a book set in the past!"
#> [3] "Hunger Games! So excited! Want go!"
The prototypical workflow for building a text predictor in sbo goes as follows:
Step 0 (optional). Build a dictionary from the training set, typically keeping the top \(V\) most frequent words:
# N.B.: get_word_freqs(train) returns a tibble with a 'word' column
# and a 'counts' column, sorted by decreasing counts.
dict <- get_word_freqs(train) %>% pull(word) %>% .[1:1000]
head(dict)
#> [1] "the" "to" "i" "a" "you" "and"
Alternatively, one may use a predefined dictionary.
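For instance, any character vector of words can play this role (the toy my_dict below is purely illustrative and is not used in the following steps):

# A toy predefined dictionary; any character vector of words can be
# passed as the dictionary argument of get_kgram_freqs() in Step 1.
my_dict <- c("the", "to", "i", "a", "you", "and")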
Step 1. Get \(k\)-gram frequencies from the training corpus:
(freqs <- get_kgram_freqs(train, N = 3, dict)) # 'N' is the order of n-grams
#> k-gram frequency table
#>
#> Order (N): 3
#> Dictionary size: 1000 words
#>
#> # of unique 1-grams: 1002
#> # of unique 2-grams: 85677
#> # of unique 3-grams: 318410
#>
#> Object size: 5.9 Mb
#>
#> See ?get_kgram_freqs for help.
Step 2. Build next-word prediction tables:
( sbo <- build_sbo_preds(freqs) )
#> Next-word prediction table for Stupid Back-off n-gram model
#>
#> Order (N): 3
#> Dictionary size: 1000 words
#> Back-off penalization (lambda): 0.4
#> Maximum number of predictions (L): 3
#>
#> Object size: 1.5 Mb
#>
#> See ?build_sbo_preds, ?predict.sbo_preds for help.
At this point we can predict next words from our model by using predict (see ?predict.sbo_preds for help on the relevant predict method):
predict(sbo, "i love") # a character vector
#> [1] "you" "it" "the"
predict(sbo, c("Colorless green ideas sleep", "See you")) # a char matrix
#> [,1] [,2] [,3]
#> [1,] "<EOS>" "in" "and"
#> [2,] "there" "<EOS>" "at"
Last but not least, we can employ our model to generate some beautiful nonsense:
set.seed(840)
babble(sbo)
#> [1] "who's ready."
babble(sbo)
#> [1] "isn't it."
babble(sbo)
#> [1] "news is welcome and best."
If we wish to save the frequency tables or the final prediction tables and reload them in a future session, we can easily do so through save/load.
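For example (a minimal sketch; the file names here are arbitrary):

save(freqs, file = "twitter_freqs.rda") # k-gram frequency tables
save(sbo, file = "twitter_sbo.rda")     # next-word prediction tables
# ...then, in a future session:
load("twitter_freqs.rda")
load("twitter_sbo.rda")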
For convenience, the objects created in this section are preloaded in sbo as twitter_dict, twitter_freqs and twitter_sbo.
At the present stage, both get_word_freqs and get_kgram_freqs employ internal functions for text preprocessing and tokenization. Preprocessing applies a fixed sequence of steps to the input text (see ?get_kgram_freqs for details). Words (including the Begin/End-Of-Sentence tokens) are then tokenized by splitting sentences at space characters. In get_kgram_freqs, each out-of-vocabulary word is replaced by an unknown-word token.
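The details are internal to the package, but the tokenization step can be pictured roughly as follows (a sketch under assumed conventions: the <EOS> token appears in the outputs above, whereas the "<UNK>" unknown-word token name is hypothetical):

# A rough sketch of the tokenization described above -- NOT sbo's
# actual internal code.
tokenize <- function(sentence, dict) {
  words <- strsplit(sentence, " +")[[1]]           # split at spaces
  words[!(words %in% c(dict, "<EOS>"))] <- "<UNK>" # replace out-of-vocabulary words
  words
}
tokenize("want to go home <EOS>", dict = c("want", "to", "go"))
#> [1] "want"  "to"    "go"    "<UNK>" "<EOS>"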
Once we have built our next-word predictor, we may want to directly test its predictions on an independent corpus. For this purpose, sbo offers the function eval_sbo_preds, which performs the following test: from each sentence of the test corpus it samples a single \(N\)-gram, feeds the first \(N-1\) words to the predictor, and checks whether the true completion (the last word of the \(N\)-gram) appears among the predictions.
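In essence, each row of this test boils down to a check of the following kind (a single hand-made trial, shown only to fix ideas):

ngram <- c("see", "you", "there")          # a 3-gram sampled from a test tweet
input <- paste(ngram[1:2], collapse = " ") # its (N-1)-gram prefix: "see you"
true <- ngram[3]                           # the true completion
true %in% predict(sbo, input)              # TRUE if one prediction is correct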
As a concrete example, we test the text predictor trained in the previous section on the independent Twitter test set.
set.seed(840)
(eval <- eval_sbo_preds(sbo, test))
#> # A tibble: 18,497 x 4
#> input true preds[,1] [,2] [,3] correct
#> <chr> <chr> <chr> <chr> <chr> <lgl>
#> 1 "oh hey" shirtless a <EOS> i'm FALSE
#> 2 " " how i <EOS> thanks FALSE
#> 3 " ah" no <EOS> yes i FALSE
#> 4 "he estudiado" <EOS> <EOS> the it TRUE
#> 5 "nada d" <EOS> <EOS> from project TRUE
#> 6 "mama no" esta <EOS> more matter FALSE
#> 7 "ya mean" <EOS> <EOS> to i TRUE
#> 8 "tennis the" scoring word same best FALSE
#> 9 " thanks" for for <EOS> to TRUE
#> 10 "concert wasn't" over a even that FALSE
#> # … with 18,487 more rows
As can be seen, eval_sbo_preds returns a tibble containing the input \((N-1)\)-grams, the true completions, the predicted completions, and a column indicating whether any of the predictions was correct.
We can estimate the predictive accuracy as follows (the uncertainty in the estimate is approximated by the binomial formula \(\sigma = \sqrt{\frac{p(1-p)}{M}}\), where \(p\) is the observed accuracy and \(M\) is the number of trials):
eval %>% summarise(accuracy = sum(correct)/n(),
uncertainty = sqrt( accuracy*(1-accuracy) / n() )
)
#> # A tibble: 1 x 2
#> accuracy uncertainty
#> <dbl> <dbl>
#> 1 0.344 0.00349
We may want to exclude from the test the \(N\)-grams ending with the End-Of-Sentence token (represented by "<EOS>" in the output above):
eval %>% # Accuracy for in-sentence predictions
filter(true != "<EOS>") %>%
summarise(accuracy = sum(correct)/n(),
uncertainty = sqrt( accuracy*(1-accuracy) / n() )
)
#> # A tibble: 1 x 2
#> accuracy uncertainty
#> <dbl> <dbl>
#> 1 0.344 0.00349
In trying to reduce the size (in physical memory) of your text predictor, it might be useful to prune the model dictionary. The following command plots a histogram of the dictionary rank of the true word, for the correct in-sentence predictions in our test.
if (require(ggplot2)) {
eval %>%
filter(correct, true != "<EOS>") %>%
transmute(rank = match(true, table = sbo$dict)) %>%
ggplot(aes(x = rank)) + geom_histogram(binwidth = 25)
}
#> Loading required package: ggplot2
#> Warning: Removed 3471 rows containing non-finite values (stat_bin).
Apparently, the large majority of correct predictions comes from the first ~300 words of the dictionary, so that pruning the dictionary to exclude words of rank greater than, say, 500 should reduce the size of our model without seriously affecting its prediction accuracy.
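For instance, one could simply rebuild the predictor from a truncated dictionary (a sketch reusing the objects created above; the 500-word cut-off is the one suggested by the histogram):

dict_small <- dict[1:500] # keep only the top 500 words
freqs_small <- get_kgram_freqs(train, N = 3, dict_small)
sbo_small <- build_sbo_preds(freqs_small) # a lighter prediction table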