Generate a list of predictability matrices using a causal transformer model
Source: R/tr_causal.R (causal_pred_mats.Rd)
This function computes a list of matrices, where each matrix corresponds to a unique group specified by the by argument. Each matrix represents the predictability of every token in the input text (x) based on the preceding context, as evaluated by a causal transformer model.
Arguments
- x
A character vector of words, phrases, or texts to evaluate.
- by
A grouping variable indicating how texts are split into groups.
- sep
A string specifying how words are separated within contexts or groups. Default is " " (a single space). For languages that don't have spaces between words (e.g., Chinese), set sep = "".
- log.p
Base of the logarithm used for the output predictability values. If TRUE (the default), the natural logarithm (base e) is used. If FALSE, the raw probabilities are returned. Alternatively, log.p can be set to a numeric value specifying the base of the logarithm (e.g., 2 for base-2 logarithms). To get surprisal in bits (rather than predictability), set log.p = 1/2.
- sorted
If FALSE (the default), the returned list preserves the order in which the groups appear in by. If TRUE, the list is sorted according to by.
- model
Name of a pre-trained model or a folder containing one. Models based on "gpt2" should work; see the Hugging Face website for available models.
- checkpoint
Folder of a checkpoint.
- add_special_tokens
Whether to include special tokens. It has the same default as the AutoTokenizer method in Python.
- decode
Logical. If TRUE, decodes the tokens into human-readable strings, handling special characters and diacritics. Default is FALSE.
- config_model
List with other arguments that control how the model from Hugging Face is accessed.
- config_tokenizer
List with other arguments that control how the tokenizer from Hugging Face is accessed.
- batch_size
Maximum size of the batch. Larger batches speed up processing but take more memory.
- ...
Currently not in use.
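To illustrate the log.p settings described above, the same underlying probability can be expressed on each of the available scales. This is a standalone sketch; the numeric value is made up and does not come from any model:

```r
# A hypothetical natural-log probability, as returned with log.p = TRUE
log_p_nat <- -7.35
p <- exp(log_p_nat)  # raw probability, as returned with log.p = FALSE
bits <- -log2(p)     # surprisal in bits, as returned with log.p = 1/2
log2_p <- log2(p)    # base-2 log predictability, as with log.p = 2
```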
Details
The function splits the input x into groups specified by the by argument and processes each group independently. For each group, the model computes the predictability of each token in its vocabulary based on the preceding context.
Each matrix contains:
- Rows representing the model's vocabulary.
- Columns corresponding to tokens in the group (e.g., a sentence or paragraph).
By default, values in the matrices are the natural logarithm of word probabilities.
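Given this layout, the (log) predictability of the token that actually occurred at each position can be read off by indexing the rows with the column names. A minimal standalone sketch with a toy matrix (the vocabulary and values are made up; a real matrix returned by causal_pred_mats() is indexed the same way):

```r
# Toy stand-in for one predictability matrix:
# 3-token vocabulary (rows), 2-token group (columns)
vocab <- c("the", "apple", "tree")
mat <- matrix(log(c(0.5, 0.3, 0.2,    # probabilities at position 1
                    0.2, 0.7, 0.1)),  # probabilities at position 2
              nrow = 3, dimnames = list(vocab, c("the", "apple")))
# Log predictability of each observed token given its preceding context
obs <- mapply(function(tok, j) mat[tok, j],
              colnames(mat), seq_len(ncol(mat)))
```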
More details about causal models
A causal language model (also called a GPT-like, auto-regressive, or decoder model) is a type of large language model, usually used for text generation, that predicts the next word (or, more accurately, the next token) based on the preceding context.
If not specified, the causal model used will be the one set in the global option pangoling.causal.default; this can be accessed via getOption("pangoling.causal.default") (by default "gpt2"). To change the default option, use options(pangoling.causal.default = "newcausalmodel").
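For instance, the default can be changed and inspected as follows (a sketch; "distilgpt2" is just an example model name):

```r
# Set a different default causal model for the session
options(pangoling.causal.default = "distilgpt2")  # example model name
# Confirm the new default
getOption("pangoling.causal.default")
```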
A list of possible causal models can be found on the Hugging Face website.
Using the config_model and config_tokenizer arguments, it's possible to control how the model and tokenizer from Hugging Face are accessed; see the Python method from_pretrained for details.
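For example, arguments accepted by from_pretrained can be passed as named list entries. A sketch of such a call (not run here, since it downloads a model; revision and add_prefix_space are examples of from_pretrained arguments, not requirements):

```r
mats <- causal_pred_mats(
  x = df_sent$word,
  by = df_sent$sent_n,
  model = "gpt2",
  config_model = list(revision = "main"),            # forwarded to the model's from_pretrained
  config_tokenizer = list(add_prefix_space = FALSE)  # forwarded to the tokenizer's from_pretrained
)
```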
In case of errors when running a new model, check the status of Hugging Face at https://status.huggingface.co/.
See also
Other causal model functions:
causal_next_tokens_pred_tbl()
,
causal_words_pred()
Examples
data("df_sent")
df_sent
#> # A tidytable: 15 × 2
#> sent_n word
#> <int> <chr>
#> 1 1 The
#> 2 1 apple
#> 3 1 doesn't
#> 4 1 fall
#> 5 1 far
#> 6 1 from
#> 7 1 the
#> 8 1 tree.
#> 9 2 Don't
#> 10 2 judge
#> 11 2 a
#> 12 2 book
#> 13 2 by
#> 14 2 its
#> 15 2 cover.
list_of_mats <- causal_pred_mats(
x = df_sent$word,
by = df_sent$sent_n,
model = "gpt2"
)
#> Processing using causal model 'gpt2/' ...
#> Processing a batch of size 1 with 10 tokens.
#> Processing a batch of size 1 with 9 tokens.
# View the structure of the resulting list
list_of_mats |> str()
#> List of 2
#> $ 1: num [1:50257, 1:10] NA NA NA NA NA NA NA NA NA NA ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : chr [1:50257] "!" "\"" "#" "$" ...
#> .. ..$ : chr [1:10] "The" "Ġapple" "Ġdoesn" "'t" ...
#> $ 2: num [1:50257, 1:9] NA NA NA NA NA NA NA NA NA NA ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : chr [1:50257] "!" "\"" "#" "$" ...
#> .. ..$ : chr [1:9] "Don" "'t" "Ġjudge" "Ġa" ...
# Inspect the last rows of the first matrix
list_of_mats[[1]] |> tail()
#> The Ġapple Ġdoesn 't Ġfall Ġfar Ġfrom
#> ominated NA -15.142192 -19.096537 -32.80610 -28.53754 -25.08115 -27.33087
#> Ġregress NA -13.093204 -14.924685 -32.17484 -12.01535 -20.19420 -21.34752
#> ĠCollider NA -13.339488 -16.405062 -32.32004 -29.73265 -28.05914 -25.85092
#> Ġinformants NA -12.950152 -14.873056 -34.56336 -24.34810 -18.73824 -23.02469
#> Ġgazed NA -13.809768 -12.320757 -38.64273 -19.81046 -19.47638 -19.91819
#> <|endoftext|> NA -7.353133 -9.627634 -14.55747 -13.46111 -11.10686 -12.69976
#> Ġthe Ġtree .
#> ominated -20.71178 -21.72472 -22.202908
#> Ġregress -18.08488 -16.63838 -18.050945
#> ĠCollider -21.89429 -23.19034 -20.985302
#> Ġinformants -18.23528 -18.46630 -17.959438
#> Ġgazed -17.88233 -19.51767 -17.944866
#> <|endoftext|> -13.89558 -15.86696 -8.816073
# Inspect the last rows of the second matrix
list_of_mats[[2]] |> tail()
#> Don 't Ġjudge Ġa Ġbook Ġby Ġits
#> ominated NA -13.038710 -24.86986 -22.828133 -20.23553 -20.51216 -21.54437
#> Ġregress NA -14.332914 -12.26382 -14.384551 -14.45043 -18.78388 -17.94085
#> ĠCollider NA -15.067999 -16.55659 -18.564629 -17.95191 -20.40189 -22.59220
#> Ġinformants NA -16.096315 -18.12209 -14.583617 -18.38805 -21.28853 -16.55606
#> Ġgazed NA -15.309744 -18.21523 -18.643326 -17.84563 -20.48769 -17.91865
#> <|endoftext|> NA -6.514345 -10.62883 -9.849662 -12.23088 -10.47432 -12.96377
#> Ġcover .
#> ominated -25.93669 -25.934177
#> Ġregress -23.10072 -20.106548
#> ĠCollider -25.76248 -25.129421
#> Ġinformants -19.48792 -21.326099
#> Ġgazed -22.67084 -20.240665
#> <|endoftext|> -16.37127 -8.178905