Source code for fairseq2.recipes.lm._text_generate

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable, Mapping, TextIO, final

import torch
from typing_extensions import override

from fairseq2.context import RuntimeContext
from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer
from fairseq2.datasets import StaticBatching, SyncMode
from fairseq2.datasets.instruction import (
    GENERIC_INSTRUCTION_DATASET_FAMILY,
    InstructionDataset,
    InstructionPromptReadOptions,
)
from fairseq2.error import InternalError, ProgramError
from fairseq2.gang import Gangs
from fairseq2.generation import SamplingConfig, SequenceGenerator
from fairseq2.models.decoder import DecoderModel
from fairseq2.models.sequence import SequenceBatch
from fairseq2.recipes.common import (
    create_generator,
    create_seq_generator,
    load_dataset,
    load_text_tokenizer,
    register_extra_asset_paths,
    setup_gangs,
    setup_reference_model,
)
from fairseq2.recipes.config import (
    CommonSection,
    DatasetSection,
    GangSection,
    GeneratorSection,
    ReferenceModelSection,
    SequenceGeneratorSection,
)
from fairseq2.recipes.error import UnitError
from fairseq2.recipes.generator import AbstractGeneratorUnit, Generator
from fairseq2.recipes.metrics import SequenceGenerationMetricBag
from fairseq2.typing import CPU
from fairseq2.utils.file import FileMode
from fairseq2.utils.rng import manual_seed
from fairseq2.utils.structured import structure
from fairseq2.utils.validation import validate


@dataclass(kw_only=True)
class TextGenerateConfig:
    model: ReferenceModelSection = field(
        default_factory=lambda: ReferenceModelSection(name="llama3_8b_instruct")
    )

    dataset: TextGenerateDatasetSection = field(
        default_factory=lambda: TextGenerateDatasetSection()
    )

    gang: GangSection = field(default_factory=lambda: GangSection())

    generator: GeneratorSection = field(
        default_factory=lambda: GeneratorSection(dtype=torch.bfloat16)
    )

    seq_generator: SequenceGeneratorSection = field(
        default_factory=lambda: SequenceGeneratorSection(
            config=SamplingConfig(), batch_size=1
        )
    )

    common: CommonSection = field(default_factory=lambda: CommonSection())
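
# A minimal sketch of overriding the defaults above programmatically. The
# field values shown are illustrative assumptions, not recommendations:
#
#     config = TextGenerateConfig()
#     config.model.name = "llama3_1_8b_instruct"
#     config.seq_generator.batch_size = 4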
@dataclass(kw_only=True)
class TextGenerateDatasetSection(DatasetSection):
    name: str = "foo"  # TODO: change!

    family: str = GENERIC_INSTRUCTION_DATASET_FAMILY

    path: Path | None = None

    split: str = "default"

    min_seq_len: int = 1
    """The minimum sequence length."""

    max_seq_len: int = 8192
    """The maximum sequence length."""

    num_prefetch: int = 4
    """The number of batches to prefetch in background."""

    extras: dict[str, object] = field(default_factory=dict)
    """The dataset-specific extra options."""
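
# Continuing the sketch above, the dataset section might be pointed at a
# local prompt file. The name and path below are hypothetical; whether the
# dataset family resolves a local path this way depends on the loader:
#
#     config.dataset.name = "my_prompts"
#     config.dataset.path = Path("/data/prompts.jsonl")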
def register_text_generate_configs(context: RuntimeContext) -> None:
    registry = context.get_config_registry(TextGenerateConfig)

    preset = registry.decorator

    @preset("llama3_instruct")
    def llama3_instruct() -> TextGenerateConfig:
        return TextGenerateConfig()

    @preset("llama3_70b_instruct")
    def llama3_70b_instruct() -> TextGenerateConfig:
        config = llama3_instruct()

        config.model.name = "llama3_70b_instruct"
        config.gang.tensor_parallel_size = 8

        return config

    @preset("llama3_1_instruct")
    def llama3_1_instruct() -> TextGenerateConfig:
        config = llama3_instruct()

        config.model.name = "llama3_1_8b_instruct"

        return config

    @preset("llama3_1_70b_instruct")
    def llama3_1_70b_instruct() -> TextGenerateConfig:
        config = llama3_70b_instruct()

        config.model.name = "llama3_1_70b_instruct"

        return config
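
# The same decorator pattern can register additional presets outside this
# module. A minimal sketch; the function and preset names are hypothetical:
#
#     def register_my_configs(context: RuntimeContext) -> None:
#         registry = context.get_config_registry(TextGenerateConfig)
#
#         preset = registry.decorator
#
#         @preset("my_llama3_1_greedy")
#         def my_llama3_1_greedy() -> TextGenerateConfig:
#             config = TextGenerateConfig()
#
#             config.model.name = "llama3_1_8b_instruct"
#
#             return config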
@torch.inference_mode()
def load_text_generator(
    context: RuntimeContext, config: object, output_dir: Path
) -> Generator[SequenceBatch]:
    config = structure(config, TextGenerateConfig)

    validate(config)

    register_extra_asset_paths(context, config)

    torch.set_float32_matmul_precision("high")

    gangs = setup_gangs(context, config)

    seed = config.common.seed

    manual_seed(seed, CPU, gangs.root.device)

    seed += 1

    model = setup_reference_model(
        DecoderModel,
        context,
        config.model.name,
        gangs,
        config.generator.dtype,
        config.generator.amp,
        config.generator.torch_compile,
    )

    dataset = load_dataset(InstructionDataset, context, config, gangs)

    tokenizer = load_text_tokenizer(context, config)

    # Initialize the unit.
    seq_generator = create_seq_generator(context, config, model)

    if gangs.tp.rank == 0:
        file_system = context.file_system

        rank = gangs.dp.rank

        text_file = output_dir.joinpath(f"output/rank_{rank}.txt")
        json_file = output_dir.joinpath(f"output/rank_{rank}.jsonl")

        try:
            try:
                file_system.make_directory(text_file.parent)
            except OSError as ex:
                raise UnitError(
                    f"The '{text_file.parent}' output directory cannot be created. See the nested exception for details."
                ) from ex

            try:
                text_fp = file_system.open_text(text_file, mode=FileMode.WRITE)
            except OSError as ex:
                raise UnitError(
                    f"The '{text_file}' output file cannot be created. See the nested exception for details."
                ) from ex

            try:
                json_fp = file_system.open_text(json_file, mode=FileMode.WRITE)
            except OSError as ex:
                raise UnitError(
                    f"The '{json_file}' output file cannot be created. See the nested exception for details."
                ) from ex
        except UnitError as ex:
            raise ProgramError(
                "The generation unit cannot be initialized. See the nested exception for details."
            ) from ex
    else:
        text_fp = None
        json_fp = None

    unit = TextGenerateUnit(
        seq_generator,
        tokenizer,
        gangs,
        text_output_stream=text_fp,
        json_output_stream=json_fp,
    )

    batching = StaticBatching(config.seq_generator.batch_size)

    read_options = InstructionPromptReadOptions(
        batching=batching,
        sync_mode=SyncMode.UNTIL_LAST,
        num_prefetch=config.dataset.num_prefetch,
        seed=seed,
        extras=config.dataset.extras,
    )

    data_reader = dataset.create_prompt_reader(
        config.dataset.split,
        tokenizer,
        gangs.dp,
        config.dataset.min_seq_len,
        config.dataset.max_seq_len,
        read_options,
    )

    seed += 1

    return create_generator(context, config, output_dir, unit, data_reader, gangs, seed)
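
# A sketch of how the factory above might be invoked, assuming a
# `RuntimeContext` is already available and that `structure` accepts a plain
# mapping as its unstructured input (an assumption; e.g. a config parsed from
# YAML). How the returned `Generator` is run is up to the recipe runner:
#
#     generator = load_text_generator(
#         context, {"model": {"name": "llama3_8b_instruct"}}, Path("out")
#     )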
@final
class TextGenerateUnit(AbstractGeneratorUnit[SequenceBatch]):
    """Represents a text generation unit."""

    _generator: SequenceGenerator
    _text_decoder: TextTokenDecoder
    _text_output_stream: TextIO | None
    _json_output_stream: TextIO | None
    _metric_bag: SequenceGenerationMetricBag

    def __init__(
        self,
        generator: SequenceGenerator,
        tokenizer: TextTokenizer,
        gangs: Gangs,
        text_output_stream: TextIO | None,
        json_output_stream: TextIO | None,
    ) -> None:
        super().__init__(generator.model)

        self._generator = generator
        self._text_decoder = tokenizer.create_decoder()
        self._text_output_stream = text_output_stream
        self._json_output_stream = json_output_stream
        self._metric_bag = SequenceGenerationMetricBag(gangs.dp)

    @override
    def __call__(self, batch: SequenceBatch) -> None:
        if batch.example is None:
            raise ValueError("`batch.example` must not be `None`.")

        if not isinstance(batch.example, Mapping):
            raise TypeError(
                f"`batch.example` must be of type `{Mapping}`, but is of type `{type(batch.example)}` instead."
            )

        try:
            prompts = batch.example["prompt"]
        except KeyError:
            raise ValueError("`batch.example` must contain a 'prompt' item.") from None

        if not isinstance(prompts, Iterable):
            raise TypeError(
                f"`batch.example['prompt']` must be an iterable of strings, but is of type `{type(prompts)}` instead."
            )

        ids = batch.example["id"]

        output = self._generator(batch.seqs, batch.padding_mask)

        self._metric_bag.update_batch_metrics(output)

        # Check if we are in the first tensor parallel group.
        if self._text_output_stream is None and self._json_output_stream is None:
            return

        try:
            for id_, prompt, hypotheses in zip(ids, prompts, output.hypotheses):
                if len(hypotheses) == 0:
                    raise InternalError(
                        "The sequence generator returned no hypothesis. Please file a bug report."
                    )

                hypothesis = hypotheses[0]

                seq = hypothesis.seq

                response = self._text_decoder(seq)

                token_indices = seq.tolist()

                if hypothesis.score is None:
                    score = None
                else:
                    score = float(hypothesis.score)

                if hypothesis.step_scores is None:
                    step_scores = None
                else:
                    step_scores = hypothesis.step_scores.tolist()

                # Dump as text.
                stream = self._text_output_stream
                if stream is not None:
                    if id_ is not None:
                        stream.write("<<<<< ID >>>>>")
                        stream.write("\n")
                        stream.write(f"{id_}")
                        stream.write("\n\n")

                    stream.write("<<<<< PROMPT >>>>>")
                    stream.write("\n")
                    stream.write(prompt)

                    stream.write("\n\n")
                    stream.write("<<<<< RESPONSE >>>>>")
                    stream.write("\n")
                    stream.write(response)

                    stream.write("\n\n")
                    stream.write("<<<<< TOKEN INDICES >>>>>")
                    stream.write("\n")
                    stream.write(", ".join(f"{t}" for t in token_indices))

                    if score is not None:
                        stream.write("\n\n")
                        stream.write("<<<<< SCORE >>>>>")
                        stream.write("\n")
                        stream.write(f"{score:.8f}")

                    if step_scores is not None:
                        stream.write("\n\n")
                        stream.write("<<<<< STEP SCORES >>>>>")
                        stream.write("\n")
                        stream.write(", ".join(f"{s:.8f}" for s in step_scores))

                    stream.write("\n\n\n============================\n\n\n")

                # Dump as JSON.
                stream = self._json_output_stream
                if stream is not None:
                    json_output = {
                        "id": id_,
                        "prompt": prompt,
                        "response": response,
                        "token_indices": token_indices,
                        "score": score,
                        "step_scores": step_scores,
                    }

                    json.dump(json_output, stream, indent=None)

                    stream.write("\n")

            stream = self._text_output_stream
            if stream is not None:
                stream.flush()

            stream = self._json_output_stream
            if stream is not None:
                stream.flush()
        except OSError as ex:
            raise UnitError(
                "The generator output cannot be written to the stream. See the nested exception for details."
            ) from ex

    @property
    @override
    def metric_bag(self) -> SequenceGenerationMetricBag:
        return self._metric_bag
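
# For reference, each line of the rank_{dp_rank}.jsonl file written above is a
# JSON object of the following shape (values are illustrative placeholders):
#
#     {"id": "sample-0", "prompt": "...", "response": "...",
#      "token_indices": [128000, 9906], "score": -1.23456789,
#      "step_scores": [-0.5, -0.7]}
#
# `score` and `step_scores` are null when the sequence generator does not
# produce scores.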