
This function computes a list of matrices, one for each unique group specified by the by argument. Each matrix gives, for every token position in the input text (x), the predictability of every token in the model's vocabulary based on the preceding context, as evaluated by a causal transformer model.

Usage

causal_pred_mats(
  x,
  by = rep(1, length(x)),
  sep = " ",
  log.p = getOption("pangoling.log.p"),
  sorted = FALSE,
  model = getOption("pangoling.causal.default"),
  checkpoint = NULL,
  add_special_tokens = NULL,
  decode = FALSE,
  config_model = NULL,
  config_tokenizer = NULL,
  batch_size = 1,
  ...
)

Arguments

x

A character vector of words, phrases, or texts to evaluate.

by

A grouping variable indicating how texts are split into groups.

sep

A string specifying how words are separated within contexts or groups. Default is " ". For languages that don't have spaces between words (e.g., Chinese), set sep = "".

log.p

Base of the logarithm used for the output predictability values. If TRUE (default), the natural logarithm (base e) is used. If FALSE, the raw probabilities are returned. Alternatively, log.p can be set to a numeric value specifying the base of the logarithm (e.g., 2 for base-2 logarithms). To get surprisal in bits (rather than predictability), set log.p = 1/2.
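The relationship between the log.p settings can be checked in base R. This is a minimal sketch on a hypothetical probability vector p, independent of pangoling; it only illustrates why a base of 1/2 yields surprisal in bits.

```r
# Hypothetical next-token probabilities (not model output)
p <- c(0.5, 0.25, 0.125)

ln_p   <- log(p)              # like log.p = TRUE: natural log
log2_p <- log(p, base = 2)    # like log.p = 2: base-2 log (negative values)
bits   <- log(p, base = 1/2)  # like log.p = 1/2: surprisal in bits (positive)

# A base-1/2 logarithm is the negated base-2 logarithm
stopifnot(isTRUE(all.equal(bits, -log2_p)))

# Converting natural-log values to surprisal in bits after the fact
bits_from_ln <- -ln_p / log(2)
bits_from_ln
#> [1] 1 2 3
```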

sorted

If FALSE (the default), the returned list keeps the original order of the groups in by. If TRUE, the list is sorted according to by.
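The two sorted settings can be pictured with base R's split(). The sketch below is an analogy under that assumption, not pangoling's internal code: split() orders groups by the sorted factor levels, whereas fixing the levels with unique() keeps the order in which groups first appear.

```r
x  <- c("a", "b", "c", "d", "e")
by <- c(2, 2, 1, 1, 3)

# Analogue of sorted = TRUE: groups ordered by the sorted values of by
sorted_groups <- split(x, by)
names(sorted_groups)
#> [1] "1" "2" "3"

# Analogue of sorted = FALSE: groups kept in order of first appearance
ordered_groups <- split(x, factor(by, levels = unique(by)))
names(ordered_groups)
#> [1] "2" "1" "3"
```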

model

Name of a pre-trained model or folder. Models based on "gpt2" should work. See the Hugging Face website.

checkpoint

Folder containing a checkpoint.

add_special_tokens

Whether to include special tokens. The default is the same as that of the AutoTokenizer class in Python.

decode

Logical. If TRUE, decodes the tokens into human-readable strings, handling special characters and diacritics. Default is FALSE.

config_model

List with other arguments that control how the model from Hugging Face is accessed.

config_tokenizer

List with other arguments that control how the tokenizer from Hugging Face is accessed.

batch_size

Maximum size of the batch. Larger batches speed up processing but take more memory.
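As a rough illustration of what batch_size controls, the sketch below partitions group indices into consecutive batches of at most batch_size elements. The helper make_batches() is hypothetical and not part of pangoling; how the package actually forms its batches may differ.

```r
# Hypothetical helper (not a pangoling function): split group indices
# into consecutive batches containing at most batch_size groups each
make_batches <- function(n_groups, batch_size = 1) {
  idx <- seq_len(n_groups)
  split(idx, ceiling(idx / batch_size))
}

# Five groups with batch_size = 2 give batches of sizes 2, 2 and 1
make_batches(5, batch_size = 2)
```

With the default batch_size = 1, each group forms its own batch, which matches the two "batch of size 1" messages in the example below.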

...

Currently not in use.

Value

A list of matrices, with tokens in their columns and the model's vocabulary in their rows.

Details

The function splits the input x into groups specified by the by argument and processes each group independently. For each group, the model computes, at every token position, the predictability of each token in its vocabulary based on the preceding context.
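The grouping step can be pictured in base R: split the words by by, then join each group with sep to form one text per group. This is a sketch of the idea using a few words from df_sent typed out by hand; pangoling's own preprocessing may differ in details such as group ordering.

```r
# A few words from the two sentences in df_sent, entered by hand
word   <- c("The", "apple", "doesn't", "Don't", "judge")
sent_n <- c(1, 1, 1, 2, 2)
sep    <- " "

# One string per group; each group is then scored independently
texts <- vapply(split(word, sent_n), paste, character(1), collapse = sep)
texts
```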

Each matrix contains:

  • Rows representing the model's vocabulary.

  • Columns corresponding to tokens in the group (e.g., a sentence or paragraph).

  • By default, values in the matrices are the natural logarithms of token probabilities. The first column is NA because the first token has no preceding context.
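Given one of the returned matrices, the log-probability of each observed token given its preceding context can be read off column by column. The sketch below builds a tiny hypothetical matrix with the same layout (vocabulary in rows, tokens in columns, first column NA) and assumes, as the example output suggests, that column j holds the distribution for position j given tokens 1 to j - 1.

```r
# Hypothetical 3-entry vocabulary and 3-token text (not model output);
# values are natural-log probabilities
vocab  <- c("The", "apple", "tree")
tokens <- c("The", "apple", "tree")
probs <- cbind(
  c(NA,  NA,  NA),   # first token: no preceding context
  c(0.2, 0.5, 0.3),  # distribution for position 2 given "The"
  c(0.1, 0.1, 0.8)   # distribution for position 3 given "The apple"
)
mat <- log(probs)
dimnames(mat) <- list(vocab, tokens)

# Log-probability of each observed token: row = the token, column = its position
obs_lp <- mat[cbind(match(tokens, rownames(mat)), seq_along(tokens))]
obs_lp

# Sanity check: every non-NA column is a proper probability distribution
colSums(exp(mat))
```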

More details about causal models

A causal language model (also called a GPT-like, auto-regressive, or decoder model) is a type of large language model, usually used for text generation, that predicts the next word (more precisely, the next token) based on the preceding context.
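At each position, a causal model assigns a score (logit) to every vocabulary entry and turns those scores into a probability distribution with a softmax. The sketch below is a numerically stable log-softmax in base R applied to hypothetical logits; it illustrates how one column of log-probabilities can arise, not how pangoling or the underlying Python library computes it internally.

```r
# Numerically stable log-softmax: subtract the maximum before exponentiating
log_softmax <- function(logits) {
  m <- max(logits)
  logits - m - log(sum(exp(logits - m)))
}

# Hypothetical logits for a 4-entry vocabulary at one position
logits <- c(2.0, 1.0, 0.1, -1.5)
lp <- log_softmax(logits)

sum(exp(lp))   # probabilities over the vocabulary sum to 1
which.max(lp)  # index of the most predictable next token
```

Subtracting the maximum leaves the result unchanged mathematically but keeps exp() from overflowing when logits are large.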

If not specified, the causal model used is the one set in the global option pangoling.causal.default, which can be accessed via getOption("pangoling.causal.default") (by default "gpt2"). To change the default, use options(pangoling.causal.default = "newcausalmodel").
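Because the default model is just a global R option, it can also be changed temporarily and restored afterwards. A minimal base R sketch, where "newcausalmodel" is the placeholder name used in the text above rather than a real model:

```r
# Set a new default, keeping the previous value so it can be restored
old <- options(pangoling.causal.default = "newcausalmodel")
getOption("pangoling.causal.default")
#> [1] "newcausalmodel"

# ... calls that rely on the default causal model would go here ...

# Restore the previous default
options(old)
```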

A list of possible causal models can be found on the Hugging Face website.

Using the config_model and config_tokenizer arguments, it is possible to control how the model and tokenizer from Hugging Face are accessed; see the Python method from_pretrained for details.

In case of errors when a new model is run, check the status of https://status.huggingface.co/

See also

Other causal model functions: causal_next_tokens_pred_tbl(), causal_words_pred()

Examples

data("df_sent")
df_sent
#> # A tidytable: 15 × 2
#>    sent_n word   
#>     <int> <chr>  
#>  1      1 The    
#>  2      1 apple  
#>  3      1 doesn't
#>  4      1 fall   
#>  5      1 far    
#>  6      1 from   
#>  7      1 the    
#>  8      1 tree.  
#>  9      2 Don't  
#> 10      2 judge  
#> 11      2 a      
#> 12      2 book   
#> 13      2 by     
#> 14      2 its    
#> 15      2 cover. 
list_of_mats <- causal_pred_mats(
  x = df_sent$word,
  by = df_sent$sent_n,
  model = "gpt2"
)
#> Processing using causal model 'gpt2/' ...
#> Processing a batch of size 1 with 10 tokens.
#> Processing a batch of size 1 with 9 tokens.

# View the structure of the resulting list
list_of_mats |> str()
#> List of 2
#>  $ 1: num [1:50257, 1:10] NA NA NA NA NA NA NA NA NA NA ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:50257] "!" "\"" "#" "$" ...
#>   .. ..$ : chr [1:10] "The" "Ġapple" "Ġdoesn" "'t" ...
#>  $ 2: num [1:50257, 1:9] NA NA NA NA NA NA NA NA NA NA ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:50257] "!" "\"" "#" "$" ...
#>   .. ..$ : chr [1:9] "Don" "'t" "Ġjudge" "Ġa" ...

# Inspect the last rows of the first matrix
list_of_mats[[1]] |> tail()
#>               The     Ġapple     Ġdoesn        't     Ġfall      Ġfar     Ġfrom
#> ominated       NA -15.142192 -19.096537 -32.80610 -28.53754 -25.08115 -27.33087
#> Ġregress       NA -13.093204 -14.924685 -32.17484 -12.01535 -20.19420 -21.34752
#> ĠCollider      NA -13.339488 -16.405062 -32.32004 -29.73265 -28.05914 -25.85092
#> Ġinformants    NA -12.950152 -14.873056 -34.56336 -24.34810 -18.73824 -23.02469
#> Ġgazed         NA -13.809768 -12.320757 -38.64273 -19.81046 -19.47638 -19.91819
#> <|endoftext|>  NA  -7.353133  -9.627634 -14.55747 -13.46111 -11.10686 -12.69976
#>                    Ġthe     Ġtree          .
#> ominated      -20.71178 -21.72472 -22.202908
#> Ġregress      -18.08488 -16.63838 -18.050945
#> ĠCollider     -21.89429 -23.19034 -20.985302
#> Ġinformants   -18.23528 -18.46630 -17.959438
#> Ġgazed        -17.88233 -19.51767 -17.944866
#> <|endoftext|> -13.89558 -15.86696  -8.816073

# Inspect the last rows of the second matrix
list_of_mats[[2]] |> tail()
#>               Don         't    Ġjudge         Ġa     Ġbook       Ġby      Ġits
#> ominated       NA -13.038710 -24.86986 -22.828133 -20.23553 -20.51216 -21.54437
#> Ġregress       NA -14.332914 -12.26382 -14.384551 -14.45043 -18.78388 -17.94085
#> ĠCollider      NA -15.067999 -16.55659 -18.564629 -17.95191 -20.40189 -22.59220
#> Ġinformants    NA -16.096315 -18.12209 -14.583617 -18.38805 -21.28853 -16.55606
#> Ġgazed         NA -15.309744 -18.21523 -18.643326 -17.84563 -20.48769 -17.91865
#> <|endoftext|>  NA  -6.514345 -10.62883  -9.849662 -12.23088 -10.47432 -12.96377
#>                  Ġcover          .
#> ominated      -25.93669 -25.934177
#> Ġregress      -23.10072 -20.106548
#> ĠCollider     -25.76248 -25.129421
#> Ġinformants   -19.48792 -21.326099
#> Ġgazed        -22.67084 -20.240665
#> <|endoftext|> -16.37127  -8.178905