LLM Fine-tuning

LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA, with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline

The core of this example is the SPDL data loading pipeline:

  1. DistributedRandomSampler – distributes sample indices across ranks with per-epoch reshuffling

  2. pipe(tokenize, concurrency=N) – concurrent Alpaca-format prompt formatting and tokenization

  3. aggregate(batch_size) – groups into batches

  4. pipe(collate) – stacks tensors

  5. add_sink(buffer_size=3) – prefetch buffer for the training loop
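Ignoring concurrency and buffering, the data flow through these stages can be mimicked with plain generator composition. The following is a stdlib-only sketch for intuition — the names `sample_indices`, `aggregate`, and `run` are illustrative stand-ins, not the SPDL API:

```python
import random
from itertools import islice


def sample_indices(n: int, rank: int, world_size: int, seed: int):
    """Stage 1: shuffled indices for this rank (mimics DistributedRandomSampler)."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    yield from idx[rank::world_size]


def aggregate(iterable, batch_size: int):
    """Stage 3: group consecutive items into fixed-size batches."""
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch


def run(samples, rank=0, world_size=2, batch_size=2, seed=0):
    # Stage 2 stand-in: "tokenize" each rank-local sample (here: uppercase it).
    tokenized = (samples[i].upper()
                 for i in sample_indices(len(samples), rank, world_size, seed))
    # Stage 4 stand-in: "collate" each batch into a dict.
    return [{"batch": b} for b in aggregate(tokenized, batch_size)]


data = [f"sample-{i}" for i in range(10)]
batches = run(data)
print(batches)
```

In the real pipeline, `pipe(tokenize, concurrency=N)` runs the tokenize stage on N worker threads and `add_sink(buffer_size=3)` keeps a prefetch buffer ahead of the training loop; this sketch only shows how the stages chain.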

Data

Download instruction-following datasets:

# https://github.com/tatsu-lab/stanford_alpaca
python download_alpaca.py --output /tmp/alpaca.jsonl
# https://huggingface.co/datasets/databricks/databricks-dolly-15k
python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields):

{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}
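The example's `format_prompt` helper (in its `utils` module) renders these fields into a single training string. A typical Alpaca-style template looks roughly like the sketch below — the exact wording used by `utils.format_prompt` may differ:

```python
def format_prompt(sample: dict[str, str]) -> str:
    """Render an Alpaca-style record as one prompt+response string (illustrative)."""
    if sample.get("input"):
        # Record carries extra context: use the instruction+input template.
        return (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n\n"
            f"### Instruction:\n{sample['instruction']}\n\n"
            f"### Input:\n{sample['input']}\n\n"
            f"### Response:\n{sample['output']}"
        )
    # No input field: use the shorter instruction-only template.
    return (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        f"### Instruction:\n{sample['instruction']}\n\n"
        f"### Response:\n{sample['output']}"
    )


record = {"instruction": "Explain what a linked list is.", "input": "",
          "output": "A linked list is..."}
print(format_prompt(record))
```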

Usage

torchrun \
  --nproc_per_node 8 \
  -m spdl.examples.llm_finetune.llm_finetuning \
  --model-path /path/to/Llama-3.2-1B-Instruct \
  --data-path \
    /tmp/alpaca.jsonl \
    /tmp/dolly.jsonl

With the default settings (global batch size 8x32), the training throughput reaches roughly 570 samples/s on H100 GPUs.

Source

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA,
with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline
^^^^^^^^^^^^^^^^^^

The core of this example is the SPDL data loading pipeline:

1. ``DistributedRandomSampler`` -- distributes sample indices across ranks
   with per-epoch reshuffling
2. ``pipe(tokenize, concurrency=N)`` -- concurrent Alpaca-format prompt
   formatting and tokenization
3. ``aggregate(batch_size)`` -- groups into batches
4. ``pipe(collate)`` -- stacks tensors
5. ``add_sink(buffer_size=3)`` -- prefetch buffer for the training loop

Data
^^^^

Download instruction-following datasets::

    # https://github.com/tatsu-lab/stanford_alpaca
    python download_alpaca.py --output /tmp/alpaca.jsonl
    # https://huggingface.co/datasets/databricks/databricks-dolly-15k
    python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields)::

    {"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}

Usage
^^^^^

::

    torchrun \\
      --nproc_per_node 8 \\
      -m spdl.examples.llm_finetune.llm_finetuning \\
      --model-path /path/to/Llama-3.2-1B-Instruct \\
      --data-path \\
        /tmp/alpaca.jsonl \\
        /tmp/dolly.jsonl

With the default settings (global batch size 8x32), the training throughput
reaches roughly 570 samples/s on H100 GPUs.
"""

from __future__ import annotations

__all__ = [
    "build_model",
    "build_pipeline",
    "iterate_pipeline",
    "load_data",
    "main",
    "train",
]

# pyre-strict

import argparse
import logging
import os
import threading
import time
from collections.abc import Callable, Iterator
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist
from spdl.pipeline import PipelineBuilder
from spdl.source import DistributedRandomSampler
from torch.nn.parallel import DistributedDataParallel as DDP

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase

try:
    from examples.llm_finetune.utils.utils import (  # pyre-ignore[21]
        format_prompt,
        load_data,
        report_progress,
        resolve_model_path,
    )
except ImportError:
    from spdl.examples.llm_finetune.utils.utils import (
        format_prompt,
        load_data,
        report_progress,
        resolve_model_path,
    )

_LG: logging.Logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# SPDL data pipeline
# ---------------------------------------------------------------------------


def build_pipeline(
    samples: list[dict[str, str]],
    tokenizer: PreTrainedTokenizerBase,
    max_seq_len: int,
    batch_size: int,
    rank: int,
    world_size: int,
    num_threads: int,
    seed: int,
) -> PipelineBuilder:
    """Build an SPDL pipeline for concurrent tokenization.

    Pipeline stages:
      1. Source: DistributedRandomSampler yields sample indices
      2. Pipe: Look up sample by index
      3. Pipe (concurrent): Format prompt and tokenize
      4. Aggregate: Group into batches
      5. Pipe: Collate into tensors
      6. Sink: Buffer for the training loop
    """

    sampler = DistributedRandomSampler(
        len(samples),
        rank=rank,
        world_size=world_size,
        seed=seed,
    )

    # The HuggingFace fast tokenizer's Rust backend is not thread-safe.
    # Use thread-local copies so each SPDL worker thread has its own instance.
    class _TokenizerTLS(threading.local):
        tokenizer: PreTrainedTokenizerBase | None = None

    _tls: _TokenizerTLS = _TokenizerTLS()

    def _get_tokenizer() -> PreTrainedTokenizerBase:
        if _tls.tokenizer is None:
            import copy

            _tls.tokenizer = copy.deepcopy(tokenizer)
        return _tls.tokenizer

    def lookup(idx: int) -> dict[str, str]:
        return samples[idx]

    def tokenize(sample: dict[str, str]) -> dict[str, torch.Tensor]:
        tok = _get_tokenizer()
        text = format_prompt(sample)
        enc = tok(
            text,
            max_length=max_seq_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)
        # For causal LM, labels = input_ids; mask padding with -100
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def collate(items: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
        return {
            "input_ids": torch.stack([it["input_ids"] for it in items]),
            "attention_mask": torch.stack([it["attention_mask"] for it in items]),
            "labels": torch.stack([it["labels"] for it in items]),
        }

    return (
        PipelineBuilder()
        .add_source(sampler)
        .pipe(lookup)
        .pipe(tokenize, concurrency=num_threads)
        .aggregate(batch_size)
        .pipe(collate)
        .add_sink(buffer_size=3)
    )


def iterate_pipeline(
    pipeline_builder: PipelineBuilder,
    num_threads: int,
    device: torch.device,
) -> Iterator[dict[str, torch.Tensor]]:
    """Build, run, and iterate over the SPDL pipeline, transferring to device."""
    pipeline = pipeline_builder.build(num_threads=num_threads)
    with pipeline.auto_stop():
        for batch in pipeline.get_iterator(timeout=120):
            yield {k: v.to(device, non_blocking=True) for k, v in batch.items()}


# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------


def build_model(
    model_path: str,
    device: torch.device,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
) -> torch.nn.Module:
    """Load LLaMA model and apply LoRA."""
    from peft import get_peft_model, LoraConfig, TaskType
    from transformers import AutoModelForCausalLM

    _LG.info("Loading model from %s", model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = model.to(device)
    return model


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------


def train(
    *,
    model_path: str,
    data_path: list[str],
    output_dir: str,
    max_seq_len: int,
    batch_size: int,
    num_epochs: int,
    lr: float,
    weight_decay: float,
    max_grad_norm: float,
    log_interval: int,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    num_workers: int,
    progress_fn: Callable[[int, int], None] | None = None,
) -> None:
    """Main training function, called per-rank."""
    rank: int = dist.get_rank()
    world_size: int = dist.get_world_size()
    local_rank: int = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    _LG.info("Rank %d/%d on device %s", rank, world_size, device)

    # --- Data ---
    samples = load_data(data_path)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model = build_model(
        model_path,
        device,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    ddp_model = DDP(model, device_ids=[local_rank])

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        ddp_model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        foreach=True,
    )

    num_steps_per_epoch = len(samples) // (batch_size * world_size)
    total_steps = num_steps_per_epoch * num_epochs
    if rank == 0:
        _LG.info(
            "Training: %d samples, %d epochs, %d steps/epoch, %d total steps",
            len(samples),
            num_epochs,
            num_steps_per_epoch,
            total_steps,
        )
        if progress_fn is not None:
            progress_fn(0, total_steps)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=lr * 0.1,
    )

    # --- Training loop ---
    global_step = 0
    ddp_model.train()

    for epoch in range(num_epochs):
        _LG.info("Epoch %d/%d", epoch + 1, num_epochs)

        pipeline_builder = build_pipeline(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_threads=num_workers,
            seed=epoch,  # different shuffle per epoch
        )

        t0 = time.monotonic()
        epoch_loss = 0.0
        num_batches = 0

        for batch in iterate_pipeline(pipeline_builder, num_workers, device):
            outputs = ddp_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            num_batches += 1
            global_step += 1

            if rank == 0:
                if progress_fn is not None:
                    progress_fn(global_step, total_steps)
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / num_batches
                    elapsed = time.monotonic() - t0
                    _LG.info(
                        "Step %d | loss=%.4f | lr=%.2e | %.1f samples/s",
                        global_step,
                        avg_loss,
                        scheduler.get_last_lr()[0],
                        num_batches * batch_size * world_size / elapsed,
                    )

        elapsed = time.monotonic() - t0
        if rank == 0:
            avg_loss = epoch_loss / max(num_batches, 1)
            _LG.info(
                "Epoch %d complete | avg_loss=%.4f | %.1fs | %.1f samples/s",
                epoch + 1,
                avg_loss,
                elapsed,
                num_batches * batch_size * world_size / elapsed,
            )

    # --- Save ---
    if rank == 0 and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        ddp_model.module.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        _LG.info("Model saved to %s", output_path)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    # Model
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to pretrained LLaMA model directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Directory to save fine-tuned LoRA weights",
    )
    # Data
    parser.add_argument(
        "--data-path",
        type=str,
        nargs="+",
        required=True,
        help="One or more paths to Alpaca-format JSONL files (local or manifold://).",
    )
    parser.add_argument("--max-seq-len", type=int, default=512)
    # Training
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--log-interval", type=int, default=10)
    # LoRA
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Pipeline
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Concurrent tokenization workers in the SPDL pipeline",
    )
    return parser.parse_args()


def init_logging() -> None:
    """Initialize logging."""
    rank = os.environ.get("RANK", "?")
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s [%(levelname).1s] [Rank{rank}] %(name)s: %(message)s",
    )


def main(args: argparse.Namespace) -> None:
    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))
    try:
        train(
            model_path=resolve_model_path(args.model_path),
            data_path=args.data_path,
            output_dir=args.output_dir,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            max_grad_norm=args.max_grad_norm,
            log_interval=args.log_interval,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            num_workers=args.num_workers,
            progress_fn=report_progress,
        )
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    init_logging()
    main(parse_args())

API Reference

Functions

build_model(model_path: str, device: device, lora_r: int, lora_alpha: int, lora_dropout: float) → Module

Load LLaMA model and apply LoRA.

build_pipeline(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_threads: int, seed: int) → PipelineBuilder

Build an SPDL pipeline for concurrent tokenization.

Pipeline stages:
  1. Source: DistributedRandomSampler yields sample indices

  2. Pipe: Look up sample by index

  3. Pipe (concurrent): Format prompt and tokenize

  4. Aggregate: Group into batches

  5. Pipe: Collate into tensors

  6. Sink: Buffer for the training loop
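In the concurrent tokenize stage, causal-LM labels are built as a copy of `input_ids` with padding positions set to -100 so the cross-entropy loss ignores them. The masking rule, sketched here on plain Python lists rather than tensors:

```python
IGNORE_INDEX = -100  # label value skipped by the causal-LM loss


def mask_labels(input_ids: list[int], attention_mask: list[int]) -> list[int]:
    """labels = input_ids where attended (mask == 1), IGNORE_INDEX where padded."""
    return [tok if m == 1 else IGNORE_INDEX
            for tok, m in zip(input_ids, attention_mask)]


# Two real tokens followed by two padding tokens (the pad id 0 is arbitrary).
print(mask_labels([101, 2054, 0, 0], [1, 1, 0, 0]))  # → [101, 2054, -100, -100]
```

The tensor version in the source does the same thing in one shot: `labels[attention_mask == 0] = -100`.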

iterate_pipeline(pipeline_builder: PipelineBuilder, num_threads: int, device: device) → Iterator[dict[str, Tensor]]

Build, run, and iterate over the SPDL pipeline, transferring to device.

load_data(paths: Sequence[str]) → list[dict[str, str]]

Load and concatenate data from one or more JSONL files.
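`load_data` lives in the example's `utils` module. For local files, a minimal stdlib equivalent might look like this (an illustrative sketch, not the actual implementation):

```python
import json
from collections.abc import Sequence


def load_data(paths: Sequence[str]) -> list[dict[str, str]]:
    """Load and concatenate records from one or more JSONL files."""
    records: list[dict[str, str]] = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line.strip():  # tolerate blank lines between records
                    records.append(json.loads(line))
    return records
```

The real helper also handles non-local paths (the `--data-path` help text mentions manifold:// URIs), which this sketch omits.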

main(args: Namespace) → None

train(*, model_path: str, data_path: list[str], output_dir: str, max_seq_len: int, batch_size: int, num_epochs: int, lr: float, weight_decay: float, max_grad_norm: float, log_interval: int, lora_r: int, lora_alpha: int, lora_dropout: float, num_workers: int, progress_fn: Callable[[int, int], None] | None = None) → None

Main training function, called per-rank.