Using a Bert model to get the predictability of words in their context
Source: vignettes/articles/intro-bert.Rmd
Whereas the vignette about GPT-2 presents a very popular way to calculate word probabilities with GPT-like models, masked models offer an alternative, especially when we only care about the final word that follows a certain context.
A masked language model (also called a BERT-like or encoder model) is a type of large language model that can predict the content of a masked token in a sentence. BERT is an example of a masked language model (see also Devlin et al. 2018).
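The difference from GPT-like models lies in what each model conditions on. Writing a sentence as tokens $w_1, \ldots, w_n$, the quantity each type of model estimates for the token at position $i$ can be sketched as follows (the notation is ours, not the package's):

$$
\text{GPT-like (causal):}\quad P(w_i \mid w_1, \ldots, w_{i-1})
$$

$$
\text{BERT-like (masked):}\quad P(w_i \mid w_1, \ldots, w_{i-1}, w_{i+1}, \ldots, w_n)
$$

Because a masked model also sees the material to the right of the mask, whatever follows the mask (or the absence of anything following it) shapes the prediction.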
First, load the following packages:
library(pangoling)
library(tidytable)
Note the following potential pitfall. This is a bad way to make predictions with a masked model:
masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK]")
#> Processing using masked model 'bert-base-uncased/' ...
#> # A tidytable: 30,522 × 4
#> masked_sentence token pred mask_n
#> <chr> <chr> <dbl> <int>
#> 1 The apple doesn't fall far from the [MASK] . -0.0579 1
#> 2 The apple doesn't fall far from the [MASK] ; -3.21 1
#> 3 The apple doesn't fall far from the [MASK] ! -4.83 1
#> 4 The apple doesn't fall far from the [MASK] ? -5.33 1
#> 5 The apple doesn't fall far from the [MASK] ... -7.84 1
#> 6 The apple doesn't fall far from the [MASK] | -8.11 1
#> 7 The apple doesn't fall far from the [MASK] tree -8.76 1
#> 8 The apple doesn't fall far from the [MASK] - -9.69 1
#> 9 The apple doesn't fall far from the [MASK] ' -9.87 1
#> 10 The apple doesn't fall far from the [MASK] : -10.5 1
#> # ℹ 30,512 more rows
(The pretrained models and tokenizers will be downloaded from https://huggingface.co/ the first time they are used.)
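The pred column is on the log scale. A minimal sketch of turning the top rows above back into probabilities with base R's exp(), assuming pred holds natural-log probabilities (the values are consistent with this: the top four tokens already account for almost all of the probability mass). The numbers are copied from the output above; the name top_pred is ours.

```r
# Log-probabilities of the four most likely tokens, copied from the output above.
# Assumption: `pred` stores natural-log probabilities.
top_pred <- c("." = -0.0579, ";" = -3.21, "!" = -4.83, "?" = -5.33)

# Back-transform to probabilities
top_prob <- exp(top_pred)
round(top_prob, 3)  # approx. 0.944, 0.040, 0.008, 0.005

# Share of the total probability mass covered by these four tokens
sum(top_prob)       # approx. 0.997
```

In other words, the model puts roughly 94% of its probability on a period in this position, while "tree" (log-probability -8.76) gets well under 0.1%.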
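The most likely predictions are punctuation marks because BERT uses both the left and the right context. Here, nothing follows the mask, which tells the model that the mask is the final token of the sentence, and sentences usually end in punctuation. We get more sensible predictions by ending the sentence with a period, so that the mask is no longer the final token: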
masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK].")
#> Processing using masked model 'bert-base-uncased/' ...
#> # A tidytable: 30,522 × 4
#> masked_sentence token pred mask_n
#> <chr> <chr> <dbl> <int>
#> 1 The apple doesn't fall far from the [MASK]. tree -0.691 1
#> 2 The apple doesn't fall far from the [MASK]. ground -1.98 1
#> 3 The apple doesn't fall far from the [MASK]. sky -2.13 1
#> 4 The apple doesn't fall far from the [MASK]. table -4.02 1
#> 5 The apple doesn't fall far from the [MASK]. floor -4.31 1
#> 6 The apple doesn't fall far from the [MASK]. top -4.48 1
#> 7 The apple doesn't fall far from the [MASK]. ceiling -4.62 1
#> 8 The apple doesn't fall far from the [MASK]. window -4.87 1
#> 9 The apple doesn't fall far from the [MASK]. trees -4.94 1
#> 10 The apple doesn't fall far from the [MASK]. apple -4.95 1
#> # ℹ 30,512 more rows
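When the last word of a context is the one of interest, it can help to build the masked sentence programmatically so that the final punctuation is never forgotten. A minimal sketch with a hypothetical helper, mask_final_word(), which is not part of the package:

```r
# Hypothetical helper (not part of the package): append a mask followed by
# sentence-final punctuation, so that the model does not spend its probability
# mass on predicting the punctuation itself. Vectorised over `context`.
mask_final_word <- function(context, final_punct = ".") {
  paste0(context, " [MASK]", final_punct)
}

mask_final_word("The apple doesn't fall far from the")
#> [1] "The apple doesn't fall far from the [MASK]."
```

The resulting string can then be passed to masked_tokens_pred_tbl() as in the example above.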
We can also mask several tokens; the column mask_n indicates which mask each prediction belongs to (but bear in mind that models of this type are trained with only 10-15% of the tokens masked):
df_masks <-
masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK][MASK]")
#> Processing using masked model 'bert-base-uncased/' ...
df_masks |> filter(mask_n == 1)
#> # A tidytable: 30,522 × 4
#> masked_sentence token pred mask_n
#> <chr> <chr> <dbl> <int>
#> 1 The apple doesn't fall far from the [MASK][MASK] tree -0.738 1
#> 2 The apple doesn't fall far from the [MASK][MASK] ground -1.72 1
#> 3 The apple doesn't fall far from the [MASK][MASK] sky -2.31 1
#> 4 The apple doesn't fall far from the [MASK][MASK] table -3.67 1
#> 5 The apple doesn't fall far from the [MASK][MASK] floor -4.47 1
#> 6 The apple doesn't fall far from the [MASK][MASK] top -4.67 1
#> 7 The apple doesn't fall far from the [MASK][MASK] ceiling -4.89 1
#> 8 The apple doesn't fall far from the [MASK][MASK] window -5.02 1
#> 9 The apple doesn't fall far from the [MASK][MASK] bush -5.02 1
#> 10 The apple doesn't fall far from the [MASK][MASK] vine -5.03 1
#> # ℹ 30,512 more rows
df_masks |> filter(mask_n == 2)
#> # A tidytable: 30,522 × 4
#> masked_sentence token pred mask_n
#> <chr> <chr> <dbl> <int>
#> 1 The apple doesn't fall far from the [MASK][MASK] . -0.0570 2
#> 2 The apple doesn't fall far from the [MASK][MASK] ; -2.91 2
#> 3 The apple doesn't fall far from the [MASK][MASK] ! -7.33 2
#> 4 The apple doesn't fall far from the [MASK][MASK] ? -9.09 2
#> 5 The apple doesn't fall far from the [MASK][MASK] ... -11.9 2
#> 6 The apple doesn't fall far from the [MASK][MASK] , -12.4 2
#> 7 The apple doesn't fall far from the [MASK][MASK] - -12.8 2
#> 8 The apple doesn't fall far from the [MASK][MASK] | -13.3 2
#> 9 The apple doesn't fall far from the [MASK][MASK] so -13.4 2
#> 10 The apple doesn't fall far from the [MASK][MASK] : -13.9 2
#> # ℹ 30,512 more rows
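To see the preferred filler for each mask at a glance, we can keep the highest-scoring token for each value of mask_n. A base-R sketch using only the top two rows of each mask, copied from the output above (the names preds and best_tokens are ours; the same split() approach should also work on the full df_masks):

```r
# Top two predictions for each mask, copied from the output above
preds <- data.frame(
  mask_n = c(1, 1, 2, 2),
  token  = c("tree", "ground", ".", ";"),
  pred   = c(-0.738, -1.72, -0.0570, -2.91),
  stringsAsFactors = FALSE
)

# For each mask, keep the token with the highest log-probability
best_tokens <- sapply(
  split(preds, preds$mask_n),
  function(d) d$token[which.max(d$pred)]
)

best_tokens                         # mask 1: "tree", mask 2: "."
paste(best_tokens, collapse = "")   # "tree."
```

Note that the two masks are scored in the same forward pass, each given the (masked) sentence, so gluing together the per-mask winners is a heuristic rather than a joint prediction of both tokens.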
We can also use BERT to examine the predictability of words when both the left and the right contexts are known:
(df_sent <- data.frame(
left = c("The", "The"),
critical = c("apple", "pear"),
right = c(
"doesn't fall far from the tree.",
"doesn't fall far from the tree."
)
))
#> left critical right
#> 1 The apple doesn't fall far from the tree.
#> 2 The pear doesn't fall far from the tree.
The function masked_targets_pred() gives us the log-probability of the target word (and takes care of summing the log-probabilities when the target is composed of several tokens).
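In symbols, the summing described above can be sketched as follows, where the tokenizer splits the target $w$ into $k$ tokens $t_1, \ldots, t_k$ (the notation is ours, and the exact conditioning the package uses for each token is not spelled out here):

$$
\log P(w \mid \text{context}) = \sum_{j=1}^{k} \log P(t_j \mid \text{context})
$$

For a single-token target, $k = 1$ and the sum reduces to one term. In the run that follows, both sentences are processed as 13 tokens, which is consistent with "apple" and "pear" each being a single token for bert-base-uncased.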
df_sent <- df_sent |>
mutate(lp = masked_targets_pred(
prev_contexts = left,
targets = critical,
after_contexts = right
))
#> Processing using masked model 'bert-base-uncased/' ...
#> Processing 1 batch(es) of 13 tokens.
#> The [apple] doesn't fall far from the tree.
#> Processing 1 batch(es) of 13 tokens.
#> The [pear] doesn't fall far from the tree.
#> ***
df_sent
#> # A tidytable: 2 × 4
#> left critical right lp
#> <chr> <chr> <chr> <dbl>
#> 1 The apple doesn't fall far from the tree. -4.68
#> 2 The pear doesn't fall far from the tree. -8.60
As expected (given the popularity of the proverb), “apple” is a more likely target word than “pear”.
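The size of this difference is easier to grasp as a probability ratio. A minimal sketch with the two values copied from df_sent above, assuming lp is a natural-log probability (with the real data frame, exp(df_sent$lp[1] - df_sent$lp[2]) computes the same thing):

```r
# Log-probabilities of the two targets, copied from df_sent above
lp_apple <- -4.68
lp_pear  <- -8.60

# Difference on the log scale (nats, assuming natural logs)
lp_diff <- lp_apple - lp_pear   # 3.92

# Ratio of probabilities: how many times more likely "apple" is than "pear"
ratio <- exp(lp_diff)
ratio                            # approx. 50.4
```

So, under this model and context, "apple" is about 50 times more probable than "pear".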