neuralset.extractors.text.HuggingFaceText¶
- class neuralset.extractors.text.HuggingFaceText(*, model_name: str = 'openai-community/gpt2', device: Literal['auto', 'cpu', 'cuda', 'accelerate'] = 'auto', layers: float | list[float] | Literal['all'] = 0.6666666666666666, cache_n_layers: int | None = None, layer_aggregation: Literal['mean', 'sum', 'group_mean'] | None = 'mean', token_aggregation: Literal['first', 'last', 'mean', 'sum', 'max'] | None = 'mean', event_types: Literal['Word', 'Sentence'] = 'Word', aggregation: Literal['single', 'sum', 'mean', 'first', 'middle', 'last', 'cat', 'stack', 'trigger'] = 'single', allow_missing: bool = False, frequency: float = 0.0, infra: MapInfra = MapInfra(folder=None, cluster=None, logs='{folder}/logs/{user}/%j', job_name=None, timeout_min=25, nodes=1, tasks_per_node=1, cpus_per_task=10, gpus_per_node=1, mem_gb=None, max_pickle_size_gb=None, slurm_constraint=None, slurm_partition=None, slurm_account=None, slurm_qos=None, slurm_use_srun=False, slurm_additional_parameters=None, conda_env=None, workdir=None, permissions=511, version='v6', keep_in_ram=True, max_jobs=128, min_samples_per_job=4096, forbid_single_item_computation=False, mode='cached'), batch_size: int = 32, contextualized: bool = True, pretrained: bool | Literal['part-reversal'] = True)[source]¶
Get embeddings from HuggingFace language models. This extractor can be applied to any kind of event that has a text attribute: Word, Sentence, etc.
- Parameters:
batch_size (int) – Batch size for the language model.
contextualized (bool) – If True (the default), the context surrounding the event is used to compute the embeddings.
pretrained (bool or "part-reversal") – Use the pretrained model if True, a randomly initialized (untrained) model if False, or a custom scrambling of the pretrained weights if "part-reversal".
Note
The tokenizer truncates the input to the maximum length supported by the model. An empty context raises an error with the default HuggingFaceText, since contextualized is True by default. To get non-contextualized embeddings, set contextualized to False.
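To illustrate what the layers and token_aggregation parameters control, here is a minimal, hypothetical sketch: a fractional layers value is interpreted as a fraction of model depth, and token_aggregation collapses per-token vectors into one embedding per event. The function names and the exact layer-index mapping below are illustrative assumptions, not the actual neuralset implementation.

```python
def resolve_layer(fraction: float, n_layers: int) -> int:
    """Map a fraction of model depth to a concrete layer index.

    Assumption for illustration: the fraction is scaled over the
    available layer indices and rounded to the nearest one.
    """
    return round(fraction * (n_layers - 1))


def aggregate_tokens(embeddings: list[list[float]], mode: str) -> list[float]:
    """Collapse per-token embeddings into a single vector per event."""
    if mode == "first":
        return embeddings[0]
    if mode == "last":
        return embeddings[-1]
    # Remaining modes operate dimension-wise across tokens.
    dims = list(zip(*embeddings))
    if mode == "mean":
        return [sum(d) / len(d) for d in dims]
    if mode == "sum":
        return [sum(d) for d in dims]
    if mode == "max":
        return [max(d) for d in dims]
    raise ValueError(f"unknown token_aggregation mode: {mode}")


# GPT-2 (the default model) has 12 transformer layers; under this
# sketch the default fraction of 2/3 selects a layer about two
# thirds of the way through the network.
layer = resolve_layer(0.6666666666666666, 12)  # → 7
vec = aggregate_tokens([[1.0, 2.0], [3.0, 4.0]], "mean")  # → [2.0, 3.0]
```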