
Whereas the vignette about GPT-2 presents a very popular way to calculate word probabilities using GPT-like models, masked models offer an alternative, especially when we care only about the final word following a certain context.

A masked language model (also called BERT-like, or encoder model) is a type of large language model that can be used to predict the content of a mask in a sentence. BERT is an example of a masked language model (see also Devlin et al. 2018).

First load the following packages:

library(pangoling)
library(tidytable) # fast alternative to dplyr

Notice the following potential pitfall. This would be a bad approach for making predictions in a masked model:

masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK]")
#> Processing using masked model 'bert-base-uncased/' ...
#> # A tidytable: 30,522 × 4
#>    masked_sentence                            token     pred mask_n
#>    <chr>                                      <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK] .      -0.0579      1
#>  2 The apple doesn't fall far from the [MASK] ;      -3.21        1
#>  3 The apple doesn't fall far from the [MASK] !      -4.83        1
#>  4 The apple doesn't fall far from the [MASK] ?      -5.33        1
#>  5 The apple doesn't fall far from the [MASK] ...    -7.84        1
#>  6 The apple doesn't fall far from the [MASK] |      -8.11        1
#>  7 The apple doesn't fall far from the [MASK] tree   -8.76        1
#>  8 The apple doesn't fall far from the [MASK] -      -9.69        1
#>  9 The apple doesn't fall far from the [MASK] '      -9.87        1
#> 10 The apple doesn't fall far from the [MASK] :     -10.5         1
#> # ℹ 30,512 more rows

(The pretrained models and tokenizers will be downloaded from https://huggingface.co/ the first time they are used.)

The most likely predictions are punctuation marks because BERT uses both the left and the right context. In this case, the right context is empty, which indicates that the mask is the final token of the sentence, and sentences usually end with punctuation. More sensible results are obtained by placing a period after the mask:

masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK].")
#> Processing using masked model 'bert-base-uncased/' ...
#> # A tidytable: 30,522 × 4
#>    masked_sentence                             token     pred mask_n
#>    <chr>                                       <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK]. tree    -0.691      1
#>  2 The apple doesn't fall far from the [MASK]. ground  -1.98       1
#>  3 The apple doesn't fall far from the [MASK]. sky     -2.13       1
#>  4 The apple doesn't fall far from the [MASK]. table   -4.02       1
#>  5 The apple doesn't fall far from the [MASK]. floor   -4.31       1
#>  6 The apple doesn't fall far from the [MASK]. top     -4.48       1
#>  7 The apple doesn't fall far from the [MASK]. ceiling -4.62       1
#>  8 The apple doesn't fall far from the [MASK]. window  -4.87       1
#>  9 The apple doesn't fall far from the [MASK]. trees   -4.94       1
#> 10 The apple doesn't fall far from the [MASK]. apple   -4.95       1
#> # ℹ 30,512 more rows
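To make the difference between the two outputs above more tangible, the values in `pred` can be converted back to probabilities. The following is a minimal sketch in base R that reuses the two values reported for "tree"; it assumes that `pred` holds natural-log probabilities, so that `exp()` inverts them. The vector names are only illustrative.

```r
# Log-probabilities for "tree", copied from the two outputs above.
# Assumption: `pred` is a natural-log probability, so exp() inverts it.
lp_tree <- c(no_period = -8.76, with_period = -0.691)

# Convert to probabilities
p_tree <- exp(lp_tree)
round(p_tree, 4)

# How many times more probable "tree" becomes once the period is added
ratio <- p_tree[["with_period"]] / p_tree[["no_period"]]
ratio
```

Under that assumption, "tree" goes from a probability well below 0.001 without the period to roughly 0.5 with it.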

We can mask several tokens as well (but bear in mind that models of this type are trained with only about 10-15% of the tokens masked):

df_masks <- 
  masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK][MASK]")
#> Processing using masked model 'bert-base-uncased/' ...
df_masks |> filter(mask_n == 1)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                  token     pred mask_n
#>    <chr>                                            <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK] tree    -0.738      1
#>  2 The apple doesn't fall far from the [MASK][MASK] ground  -1.72       1
#>  3 The apple doesn't fall far from the [MASK][MASK] sky     -2.31       1
#>  4 The apple doesn't fall far from the [MASK][MASK] table   -3.67       1
#>  5 The apple doesn't fall far from the [MASK][MASK] floor   -4.47       1
#>  6 The apple doesn't fall far from the [MASK][MASK] top     -4.67       1
#>  7 The apple doesn't fall far from the [MASK][MASK] ceiling -4.89       1
#>  8 The apple doesn't fall far from the [MASK][MASK] window  -5.02       1
#>  9 The apple doesn't fall far from the [MASK][MASK] bush    -5.02       1
#> 10 The apple doesn't fall far from the [MASK][MASK] vine    -5.03       1
#> # ℹ 30,512 more rows
df_masks |> filter(mask_n == 2)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                  token     pred mask_n
#>    <chr>                                            <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK] .      -0.0570      2
#>  2 The apple doesn't fall far from the [MASK][MASK] ;      -2.91        2
#>  3 The apple doesn't fall far from the [MASK][MASK] !      -7.33        2
#>  4 The apple doesn't fall far from the [MASK][MASK] ?      -9.09        2
#>  5 The apple doesn't fall far from the [MASK][MASK] ...   -11.9         2
#>  6 The apple doesn't fall far from the [MASK][MASK] ,     -12.4         2
#>  7 The apple doesn't fall far from the [MASK][MASK] -     -12.8         2
#>  8 The apple doesn't fall far from the [MASK][MASK] |     -13.3         2
#>  9 The apple doesn't fall far from the [MASK][MASK] so    -13.4         2
#> 10 The apple doesn't fall far from the [MASK][MASK] :     -13.9         2
#> # ℹ 30,512 more rows
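Because the table stacks the predictions for every mask, it can be handy to pull out the best candidate per mask position. The sketch below does this in base R on a small excerpt typed in by hand from the output above (the top three rows for each mask); with the full `df_masks` the same grouping logic would apply. The object names `excerpt` and `best` are only illustrative.

```r
# Excerpt of df_masks copied by hand from the output above
# (top three rows per mask position)
excerpt <- data.frame(
  token  = c("tree", "ground", "sky", ".", ";", "!"),
  pred   = c(-0.738, -1.72, -2.31, -0.0570, -2.91, -7.33),
  mask_n = c(1L, 1L, 1L, 2L, 2L, 2L),
  stringsAsFactors = FALSE
)

# Split by mask position and keep the row with the highest value of pred
best <- do.call(
  rbind,
  lapply(split(excerpt, excerpt$mask_n), function(d) d[which.max(d$pred), ])
)
best
```

Note that each mask is scored separately here, so the two top tokens are not necessarily the most likely pair when taken together.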

We can also use BERT to examine the predictability of words assuming that both the left and right contexts are known:

(df_sent <- data.frame(
  left = c("The", "The"),
  critical = c("apple", "pear"),
  right = c(
    "doesn't fall far from the tree.",
    "doesn't fall far from the tree."
  )
))
#>   left critical                           right
#> 1  The    apple doesn't fall far from the tree.
#> 2  The     pear doesn't fall far from the tree.

The function masked_targets_pred() will give us the log-probability of the target word (and will take care of summing the log-probabilities in case the target is composed of several tokens).
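The arithmetic behind summing log-probabilities can be illustrated with a short base R sketch. The per-token probabilities below are made up for illustration and are not model output; the sketch only shows why a sum of logs corresponds to a product of probabilities, not how the function obtains each token's value.

```r
# Hypothetical per-token probabilities for a target split into two tokens
# (made-up numbers for illustration, not model output)
token_probs <- c(0.20, 0.50)

# Summing the log-probabilities of the tokens ...
lp_target <- sum(log(token_probs))

# ... is the same as taking the log of the product of their probabilities
all.equal(lp_target, log(prod(token_probs)))
```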

df_sent <- df_sent |>
  mutate(lp = masked_targets_pred(
    prev_contexts = left,
    targets = critical,
    after_contexts = right
  ))
#> Processing using masked model 'bert-base-uncased/' ...
#> Processing 1 batch(es) of 13 tokens.
#> The [apple] doesn't fall far from the tree.
#> Processing 1 batch(es) of 13 tokens.
#> The [pear] doesn't fall far from the tree.
#> ***
df_sent
#> # A tidytable: 2 × 4
#>   left  critical right                              lp
#>   <chr> <chr>    <chr>                           <dbl>
#> 1 The   apple    doesn't fall far from the tree. -4.68
#> 2 The   pear     doesn't fall far from the tree. -8.60

As expected (given the popularity of the proverb), “apple” is a more likely target word than “pear”.
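To put a number on "more likely", the difference between the two log-probabilities can be turned into a probability ratio. This is a minimal sketch using the values from the output above, assuming that `lp` is on the natural-log scale; the variable names are only illustrative.

```r
# Log-probabilities copied from the df_sent output above
lp <- c(apple = -4.68, pear = -8.60)

# A difference of log-probabilities is the log of the probability ratio
log_ratio <- lp[["apple"]] - lp[["pear"]]

# Exponentiate to get how many times more probable "apple" is than "pear"
exp(log_ratio)
```

Under that assumption, "apple" is roughly 50 times more probable than "pear" in this context.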

References

Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. “BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding.” arXiv. https://doi.org/10.48550/ARXIV.1810.04805.