LLM fine-tuning

LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA, with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline

The core of this example is the SPDL data loading pipeline:

  1. DistributedRandomSampler – distributes sample indices across ranks with per-epoch reshuffling

  2. pipe(tokenize, concurrency=N) – concurrent Alpaca-format prompt formatting and tokenization

  3. aggregate(batch_size) – groups into batches

  4. pipe(collate) – stacks tensors

  5. add_sink(buffer_size=3) – prefetch buffer for the training loop
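As a rough mental model of steps 2-5 (SPDL's PipelineBuilder provides the real, optimized implementation), here is a stdlib-only analogue in which a thread pool stands in for SPDL's concurrent executor and a bounded queue plays the role of the sink; tokenize and collate are placeholder callables, not SPDL APIs:

```python
import itertools
import queue
import threading
from concurrent.futures import ThreadPoolExecutor


def run_pipeline(indices, tokenize, collate, batch_size, concurrency=4, buffer_size=3):
    """Yield collated batches, mirroring pipeline steps 2-5 above."""
    sink: queue.Queue = queue.Queue(maxsize=buffer_size)  # step 5: bounded prefetch buffer
    done = object()  # sentinel marking end of stream

    def produce() -> None:
        with ThreadPoolExecutor(max_workers=concurrency) as pool:
            tokenized = iter(pool.map(tokenize, indices))  # step 2: concurrent, order-preserving map
            while batch := list(itertools.islice(tokenized, batch_size)):  # step 3: aggregate
                sink.put(collate(batch))  # step 4: collate, then hand to the buffer
        sink.put(done)

    threading.Thread(target=produce, daemon=True).start()
    while (item := sink.get()) is not done:  # the training loop consumes from the sink
        yield item
```

The bounded queue is what gives the prefetch behaviour: the producer runs ahead of the training loop by at most buffer_size batches.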

Data

Download instruction-following datasets:

# https://github.com/tatsu-lab/stanford_alpaca
python download_alpaca.py --output /tmp/alpaca.jsonl
# https://huggingface.co/datasets/databricks/databricks-dolly-15k
python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields):

{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}

Usage

torchrun \
  --nproc_per_node 8 \
  -m spdl.examples.llm_finetune.llm_finetuning \
  --model-path /path/to/Llama-3.2-1B-Instruct \
  --data-path \
    /tmp/alpaca.jsonl \
    /tmp/dolly.jsonl

With the default settings (global batch size 8x32), training throughput reaches roughly 570 samples/s on H100 GPUs.

Source

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA,
with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline
^^^^^^^^^^^^^^^^^^

The core of this example is the SPDL data loading pipeline:

1. ``DistributedRandomSampler`` -- distributes sample indices across ranks
   with per-epoch reshuffling
2. ``pipe(tokenize, concurrency=N)`` -- concurrent Alpaca-format prompt
   formatting and tokenization
3. ``aggregate(batch_size)`` -- groups into batches
4. ``pipe(collate)`` -- stacks tensors
5. ``add_sink(buffer_size=3)`` -- prefetch buffer for the training loop

Data
^^^^

Download instruction-following datasets::

    # https://github.com/tatsu-lab/stanford_alpaca
    python download_alpaca.py --output /tmp/alpaca.jsonl
    # https://huggingface.co/datasets/databricks/databricks-dolly-15k
    python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields)::

    {"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}

Usage
^^^^^

::

    torchrun \\
      --nproc_per_node 8 \\
      -m spdl.examples.llm_finetune.llm_finetuning \\
      --model-path /path/to/Llama-3.2-1B-Instruct \\
      --data-path \\
        /tmp/alpaca.jsonl \\
        /tmp/dolly.jsonl

With the default settings (global batch size 8x32), training throughput reaches
roughly 570 samples/s on H100 GPUs.
"""

from __future__ import annotations

__all__ = [
    "build_model",
    "build_pytorch_dataloader",
    "build_spdl_dataloader",
    "load_data",
    "main",
    "train",
]

# pyre-strict

import argparse
import logging
import os
import time
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

try:
    from examples.llm_finetune.utils.dataloader import (  # pyre-ignore[21]
        build_pytorch_dataloader,
    )
    from examples.llm_finetune.utils.pipeline import (  # pyre-ignore[21]
        build_spdl_dataloader,
    )
    from examples.llm_finetune.utils.utils import (  # pyre-ignore[21]
        load_data,
        report_progress,
        resolve_model_path,
    )
except ImportError:
    from spdl.examples.llm_finetune.utils.dataloader import build_pytorch_dataloader
    from spdl.examples.llm_finetune.utils.pipeline import (
        build_spdl_dataloader,
    )
    from spdl.examples.llm_finetune.utils.utils import (
        load_data,
        report_progress,
        resolve_model_path,
    )

_LG: logging.Logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------


def build_model(
    model_path: str,
    device: torch.device,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
) -> torch.nn.Module:
    """Load LLaMA model and apply LoRA."""
    from peft import get_peft_model, LoraConfig, TaskType
    from transformers import AutoModelForCausalLM

    _LG.info("Loading model from %s", model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = model.to(device)
    return model


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------


def train(
    *,
    model_path: str,
    data_path: list[str],
    output_dir: str,
    max_seq_len: int,
    batch_size: int,
    num_epochs: int,
    lr: float,
    weight_decay: float,
    max_grad_norm: float,
    log_interval: int,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    num_workers: int,
    dataloader: str = "spdl",
    mp_context: str = "forkserver",
    progress_fn: Callable[[int, int], None] | None = None,
) -> None:
    """Main training function, called per-rank."""
    rank: int = dist.get_rank()
    world_size: int = dist.get_world_size()
    local_rank: int = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    _LG.info(
        "Rank %d/%d on device %s (dataloader=%s)",
        rank,
        world_size,
        device,
        dataloader,
    )

    # --- Data ---
    samples = load_data(data_path)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model = build_model(
        model_path,
        device,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    ddp_model = DDP(model, device_ids=[local_rank])

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        ddp_model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        foreach=True,
    )

    num_steps_per_epoch = len(samples) // (batch_size * world_size)
    total_steps = num_steps_per_epoch * num_epochs
    if rank == 0:
        _LG.info(
            "Training: %d samples, %d epochs, %d steps/epoch, %d total steps",
            len(samples),
            num_epochs,
            num_steps_per_epoch,
            total_steps,
        )
        if progress_fn is not None:
            progress_fn(0, total_steps)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=lr * 0.1,
    )

    # --- Build data source ---
    if dataloader == "pytorch":
        dl = build_pytorch_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_workers=num_workers,
            mp_context=mp_context,
            device=device,
        )
    else:
        dl = build_spdl_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_threads=num_workers,
            mp_context=mp_context,
        )

    # --- Training loop ---
    global_step = 0
    ddp_model.train()

    for epoch in range(num_epochs):
        _LG.info("Epoch %d/%d", epoch + 1, num_epochs)

        t0 = time.monotonic()
        epoch_loss = 0.0
        num_batches = 0

        for batch in dl:
            outputs = ddp_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            num_batches += 1
            global_step += 1

            if rank == 0:
                if progress_fn is not None:
                    progress_fn(global_step, total_steps)
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / num_batches
                    elapsed = time.monotonic() - t0
                    _LG.info(
                        "Step %d | loss=%.4f | lr=%.2e | %.1f samples/s",
                        global_step,
                        avg_loss,
                        scheduler.get_last_lr()[0],
                        num_batches * batch_size * world_size / elapsed,
                    )

        elapsed = time.monotonic() - t0
        if rank == 0:
            avg_loss = epoch_loss / max(num_batches, 1)
            _LG.info(
                "Epoch %d complete | avg_loss=%.4f | %.1fs | %.1f samples/s",
                epoch + 1,
                avg_loss,
                elapsed,
                num_batches * batch_size * world_size / elapsed,
            )

    # --- Save ---
    if rank == 0 and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        ddp_model.module.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        _LG.info("Model saved to %s", output_path)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    # Model
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to pretrained LLaMA model directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Directory to save fine-tuned LoRA weights",
    )
    # Data
    parser.add_argument(
        "--data-path",
        type=str,
        nargs="+",
        required=True,
        help="One or more paths to Alpaca-format JSONL files (local or manifold://).",
    )
    parser.add_argument("--max-seq-len", type=int, default=512)
    # Training
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--log-interval", type=int, default=10)
    # LoRA
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Pipeline
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Concurrent tokenization workers in the data pipeline",
    )
    parser.add_argument(
        "--dataloader",
        type=str,
        choices=["spdl", "pytorch"],
        default="spdl",
        help="Data loading backend: 'spdl' (default) or 'pytorch' (torch DataLoader)",
    )
    parser.add_argument(
        "--mp-context",
        type=str,
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
        help="Multiprocessing context for workers (default: forkserver)",
    )
    return parser.parse_args()


def init_logging() -> None:
    """Initialize logging."""
    rank = os.environ.get("RANK", "?")
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s [%(levelname).1s] [Rank{rank}] %(name)s: %(message)s",
    )


def main(args: argparse.Namespace) -> None:
    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=3))
    try:
        train(
            model_path=resolve_model_path(args.model_path),
            data_path=args.data_path,
            output_dir=args.output_dir,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            max_grad_norm=args.max_grad_norm,
            log_interval=args.log_interval,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            num_workers=args.num_workers,
            dataloader=args.dataloader,
            mp_context=args.mp_context,
            progress_fn=report_progress,
        )
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    init_logging()
    main(parse_args())

API Reference

Functions

build_model(model_path: str, device: torch.device, lora_r: int, lora_alpha: int, lora_dropout: float) → torch.nn.Module

Load LLaMA model and apply LoRA.

build_pytorch_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_workers: int, mp_context: str = 'forkserver', device: torch.device | None = None) → _TDataLoader

Build a reusable PyTorch DataLoader for distributed LLM fine-tuning.

Build once before the training loop and reuse across epochs. The returned wrapper automatically calls DistributedSampler.set_epoch on each iteration, so no manual set_epoch call is needed. Similarly, yielded batches are automatically transferred to the given device, if any, so no manual to(device) call is needed.
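As a rough illustration of that automatic set_epoch behaviour (the class and attribute names here are hypothetical, not SPDL's actual internals), a wrapper can bump the sampler epoch every time it is iterated:

```python
class EpochAwareLoader:
    """Wrap a loader so each full iteration advances the sampler's epoch."""

    def __init__(self, loader, sampler):
        self._loader = loader
        self._sampler = sampler
        self._epoch = 0

    def __iter__(self):
        # Reshuffle for this epoch, then advance the counter so the next
        # `for batch in loader` pass shuffles differently -- no manual
        # set_epoch call in the training loop.
        self._sampler.set_epoch(self._epoch)
        self._epoch += 1
        yield from self._loader
```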

build_spdl_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_threads: int, mp_context: str = 'forkserver') → _TDataLoader

Build a reusable SPDL data loader with nested pipeline architecture.

Creates two nested pipelines to separate CPU-bound data loading from GPU transfer, reducing the noisy-neighbour effect where data loading threads in the main process compete with the training loop for CPU time, delaying GPU kernel launches.

Inner pipeline (runs in a subprocess):

Sampling → lookup → tokenize (concurrent) → aggregate → collate. All CPU work runs in a dedicated subprocess with its own thread pool, completely isolating it from the training process. The subprocess is created once and reused across epochs; each new for-loop over the loader rebuilds the pipeline inside the same subprocess.

Outer pipeline (runs in the main process):

Receives CPU batches from the subprocess via IPC queue and transfers them to GPU using transfer_tensor with a dedicated single-thread executor. This ensures GPU transfer uses a consistent CUDA stream and overlaps with training computation.
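The producer-subprocess-plus-IPC-queue structure can be sketched with the standard library alone. This is an illustrative analogue, not SPDL's implementation; it uses the fork start method so the demo stays runnable from a plain script without an import guard (the example itself defaults to forkserver), and GPU transfer is reduced to a comment:

```python
import multiprocessing as mp

_DONE = "__done__"  # sentinel marking end of stream


def _worker(samples, batch_size, q):
    # Inner pipeline: all CPU-bound batching work happens in this subprocess.
    for i in range(0, len(samples), batch_size):
        q.put(samples[i : i + batch_size])  # ship a finished CPU batch over IPC
    q.put(_DONE)


def iter_batches(samples, batch_size=2, buffer_size=3):
    ctx = mp.get_context("fork")  # POSIX only; see note above
    q = ctx.Queue(maxsize=buffer_size)  # bounded IPC queue = prefetch buffer
    p = ctx.Process(target=_worker, args=(samples, batch_size, q))
    p.start()
    try:
        while (batch := q.get()) != _DONE:
            # Outer pipeline: here the real loader would move `batch` to the
            # GPU on a dedicated stream before yielding it to the trainer.
            yield batch
    finally:
        p.join()
```

Because the queue is bounded, the subprocess can run at most buffer_size batches ahead of the training loop, mirroring the sink buffer in the pipeline description above.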

Build once before the training loop and iterate each epoch:

dataloader = build_spdl_dataloader(samples, tokenizer, ...)
for epoch in range(num_epochs):
    for batch in dataloader:
        train(batch)

load_data(paths: Sequence[str]) → list[dict[str, str]]

Load and concatenate data from one or more JSONL files.

main(args: Namespace) → None

train(*, model_path: str, data_path: list[str], output_dir: str, max_seq_len: int, batch_size: int, num_epochs: int, lr: float, weight_decay: float, max_grad_norm: float, log_interval: int, lora_r: int, lora_alpha: int, lora_dropout: float, num_workers: int, dataloader: str = 'spdl', mp_context: str = 'forkserver', progress_fn: Callable[[int, int], None] | None = None) → None

Main training function, called per-rank.