LLM fine-tuning

LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA, with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline

The core of this example is the SPDL data loading pipeline:

  1. DistributedRandomSampler – distributes sample indices across ranks with per-epoch reshuffling

  2. pipe(tokenize, concurrency=N) – concurrent Alpaca-format prompt formatting and tokenization

  3. aggregate(batch_size) – groups into batches

  4. pipe(collate) – stacks tensors

  5. add_sink(buffer_size=3) – prefetch buffer for the training loop
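
To make the data flow concrete, the stages above can be mimicked with the standard library alone. The sketch below is not SPDL code. The names shard_indices, tokenize, batched, collate and run_epoch are hypothetical, the tokenizer is a whitespace stand-in, and a thread pool plays the role of the concurrent tokenization stage. The prefetching sink is left out.

```python
import random
from concurrent.futures import ThreadPoolExecutor
from itertools import islice


def shard_indices(n, rank, world_size, epoch, seed=0):
    """Shuffle indices deterministically per epoch, then take this rank's share."""
    order = list(range(n))
    random.Random(seed + epoch).shuffle(order)
    return order[rank::world_size]


def tokenize(text):
    """Toy stand-in for a tokenizer: one fake token id per whitespace-separated word."""
    return [len(word) for word in text.split()]


def batched(items, batch_size):
    """Group an iterable into lists of at most batch_size items."""
    it = iter(items)
    while batch := list(islice(it, batch_size)):
        yield batch


def collate(batch):
    """Right-pad every sequence in the batch with zeros to a common length."""
    width = max(len(seq) for seq in batch)
    return [seq + [0] * (width - len(seq)) for seq in batch]


def run_epoch(samples, rank, world_size, epoch, batch_size, concurrency):
    """Yield collated batches for one rank and one epoch."""
    indices = shard_indices(len(samples), rank, world_size, epoch)
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        # pool.map tokenizes concurrently but yields results in input order.
        tokens = pool.map(tokenize, (samples[i] for i in indices))
        for batch in batched(tokens, batch_size):
            yield collate(batch)
```

Each rank calls run_epoch with its own rank value, so ranks see disjoint slices of the same per-epoch shuffle.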

Data

Download instruction-following datasets:

# https://github.com/tatsu-lab/stanford_alpaca
python download_alpaca.py --output /tmp/alpaca.jsonl
# https://huggingface.co/datasets/databricks/databricks-dolly-15k
python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields):

{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}
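
A record in this format is usually turned into a single prompt string before tokenization. The sketch below uses the prompt wording popularized by the Stanford Alpaca repository. Whether this example's tokenize step uses exactly this template is an assumption, and format_prompt is a hypothetical helper name.

```python
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)


def format_prompt(record: dict[str, str]) -> str:
    """Build the prompt text; the "output" field is the training target."""
    extra_input = record.get("input", "")
    # Records with an empty "input" field use the shorter template.
    template = PROMPT_WITH_INPUT if extra_input else PROMPT_NO_INPUT
    return template.format(instruction=record["instruction"], input=extra_input)
```

During fine-tuning the "output" text is appended after the "### Response:" marker and serves as the label sequence.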

Usage

torchrun \
  --nproc_per_node 8 \
  -m spdl.examples.llm_finetune.llm_finetuning \
  --model-path /path/to/Llama-3.2-1B-Instruct \
  --data-path \
    /tmp/alpaca.jsonl \
    /tmp/dolly.jsonl

With the default settings (global batch size 8x32), training throughput reaches roughly 570 samples per second on H100 GPUs.
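
For reference, the step count and throughput figures follow from simple arithmetic that mirrors the expressions in train(). The sketch below reads 8x32 as 8 ranks times a per-rank batch of 32, which is an assumption, and the sample count and timing are made-up illustration values.

```python
def steps_per_epoch(num_samples: int, batch_size: int, world_size: int) -> int:
    """Full global batches per epoch; the remainder is dropped."""
    return num_samples // (batch_size * world_size)


def samples_per_second(
    num_batches: int, batch_size: int, world_size: int, elapsed_s: float
) -> float:
    """Global throughput: samples consumed by all ranks per second."""
    return num_batches * batch_size * world_size / elapsed_s


# Hypothetical numbers, for illustration only.
print(steps_per_epoch(67_000, batch_size=32, world_size=8))  # 261
print(samples_per_second(10, batch_size=32, world_size=8, elapsed_s=4.0))  # 640.0
```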

Source

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA,
with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline
^^^^^^^^^^^^^^^^^^

The core of this example is the SPDL data loading pipeline:

1. ``DistributedRandomSampler`` -- distributes sample indices across ranks
   with per-epoch reshuffling
2. ``pipe(tokenize, concurrency=N)`` -- concurrent Alpaca-format prompt
   formatting and tokenization
3. ``aggregate(batch_size)`` -- groups into batches
4. ``pipe(collate)`` -- stacks tensors
5. ``add_sink(buffer_size=3)`` -- prefetch buffer for the training loop

Data
^^^^

Download instruction-following datasets::

    # https://github.com/tatsu-lab/stanford_alpaca
    python download_alpaca.py --output /tmp/alpaca.jsonl
    # https://huggingface.co/datasets/databricks/databricks-dolly-15k
    python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields)::

    {"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}

Usage
^^^^^

::

    torchrun \\
      --nproc_per_node 8 \\
      -m spdl.examples.llm_finetune.llm_finetuning \\
      --model-path /path/to/Llama-3.2-1B-Instruct \\
      --data-path \\
        /tmp/alpaca.jsonl \\
        /tmp/dolly.jsonl

With the default settings (global batch size 8x32), training throughput reaches
roughly 570 samples per second on H100 GPUs.
"""

from __future__ import annotations

__all__ = [
    "build_model",
    "build_pytorch_dataloader",
    "build_spdl_dataloader",
    "load_data",
    "main",
    "train",
]

# pyre-strict

import argparse
import logging
import os
import time
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

try:
    from examples.llm_finetune.utils.dataloader import (  # pyre-ignore[21]
        build_pytorch_dataloader,
    )
    from examples.llm_finetune.utils.pipeline import (  # pyre-ignore[21]
        build_spdl_dataloader,
    )
    from examples.llm_finetune.utils.utils import (  # pyre-ignore[21]
        load_data,
        report_progress,
        resolve_model_path,
    )
except ImportError:
    from spdl.examples.llm_finetune.utils.dataloader import build_pytorch_dataloader
    from spdl.examples.llm_finetune.utils.pipeline import (
        build_spdl_dataloader,
    )
    from spdl.examples.llm_finetune.utils.utils import (
        load_data,
        report_progress,
        resolve_model_path,
    )

_LG: logging.Logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------


def build_model(
    model_path: str,
    device: torch.device,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
) -> torch.nn.Module:
    """Load LLaMA model and apply LoRA."""
    from peft import get_peft_model, LoraConfig, TaskType
    from transformers import AutoModelForCausalLM

    _LG.info("Loading model from %s", model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = model.to(device)
    return model


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------


def train(
    *,
    model_path: str,
    data_path: list[str],
    output_dir: str,
    max_seq_len: int,
    batch_size: int,
    num_epochs: int,
    lr: float,
    weight_decay: float,
    max_grad_norm: float,
    log_interval: int,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    num_workers: int,
    dataloader: str = "spdl",
    mp_context: str = "forkserver",
    progress_fn: Callable[[int, int], None] | None = None,
) -> None:
    """Main training function, called per-rank."""
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    local_rank: int = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    _LG.info(
        "Rank %d/%d on device %s (dataloader=%s)",
        rank,
        world_size,
        device,
        dataloader,
    )

    # --- Data ---
    samples = load_data(data_path)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model = build_model(
        model_path,
        device,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    if dist.is_initialized():
        ddp_model = DDP(
            model,
            device_ids=[local_rank] if torch.cuda.is_available() else None,
        )
    else:
        ddp_model = model

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        ddp_model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        foreach=True,
    )

    num_steps_per_epoch = len(samples) // (batch_size * world_size)
    total_steps = num_steps_per_epoch * num_epochs
    if rank == 0:
        _LG.info(
            "Training: %d samples, %d epochs, %d steps/epoch, %d total steps",
            len(samples),
            num_epochs,
            num_steps_per_epoch,
            total_steps,
        )
        if progress_fn is not None:
            progress_fn(0, total_steps)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=lr * 0.1,
    )

    # --- Build data source ---
    if dataloader == "pytorch":
        dl = build_pytorch_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_workers=num_workers,
            mp_context=mp_context,
            device=device,
        )
    else:
        dl = build_spdl_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_threads=num_workers,
            mp_context=mp_context,
        )

    # --- Training loop ---
    global_step = 0
    ddp_model.train()

    for epoch in range(num_epochs):
        _LG.info("Epoch %d/%d", epoch + 1, num_epochs)

        t0 = time.monotonic()
        epoch_loss = 0.0
        num_batches = 0

        for batch in dl:
            outputs = ddp_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            num_batches += 1
            global_step += 1

            if rank == 0:
                if progress_fn is not None:
                    progress_fn(global_step, total_steps)
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / num_batches
                    elapsed = time.monotonic() - t0
                    _LG.info(
                        "Step %d | loss=%.4f | lr=%.2e | %.1f samples/s",
                        global_step,
                        avg_loss,
                        scheduler.get_last_lr()[0],
                        num_batches * batch_size * world_size / elapsed,
                    )

        elapsed = time.monotonic() - t0
        if rank == 0:
            avg_loss = epoch_loss / max(num_batches, 1)
            _LG.info(
                "Epoch %d complete | avg_loss=%.4f | %.1fs | %.1f samples/s",
                epoch + 1,
                avg_loss,
                elapsed,
                num_batches * batch_size * world_size / elapsed,
            )

    # --- Save ---
    if rank == 0 and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(output_path)  # pyre-ignore[29]
        tokenizer.save_pretrained(output_path)
        _LG.info("Model saved to %s", output_path)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    # Model
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to pretrained LLaMA model directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Directory to save fine-tuned LoRA weights",
    )
    # Data
    parser.add_argument(
        "--data-path",
        type=str,
        nargs="+",
        required=True,
        help="One or more paths to Alpaca-format JSONL files (local or manifold://).",
    )
    parser.add_argument("--max-seq-len", type=int, default=512)
    # Training
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--log-interval", type=int, default=10)
    # LoRA
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Pipeline
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Concurrent tokenization workers in the data pipeline",
    )
    parser.add_argument(
        "--dataloader",
        type=str,
        choices=["spdl", "pytorch"],
        default="spdl",
        help="Data loading backend: 'spdl' (default) or 'pytorch' (torch DataLoader)",
    )
    parser.add_argument(
        "--mp-context",
        type=str,
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
        help="Multiprocessing context for workers (default: forkserver)",
    )
    return parser.parse_args()


def init_logging() -> None:
    """Initialize logging."""
    rank = os.environ.get("RANK", "?")
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s [%(levelname).1s] [Rank{rank}] %(name)s: %(message)s",
    )


def main(args: argparse.Namespace) -> None:
    use_distributed = "RANK" in os.environ
    if use_distributed:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, timeout=timedelta(minutes=3))
    try:
        train(
            model_path=resolve_model_path(args.model_path),
            data_path=args.data_path,
            output_dir=args.output_dir,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            max_grad_norm=args.max_grad_norm,
            log_interval=args.log_interval,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            num_workers=args.num_workers,
            dataloader=args.dataloader,
            mp_context=args.mp_context,
            progress_fn=report_progress,
        )
    finally:
        if use_distributed:
            dist.destroy_process_group()


if __name__ == "__main__":
    init_logging()
    main(parse_args())

API Reference

Functions

build_model(model_path: str, device: device, lora_r: int, lora_alpha: int, lora_dropout: float) → Module

Load LLaMA model and apply LoRA.
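
As a rough guide to what LoRA adds, each adapted linear layer gains two low-rank matrices, A (r x d_in) and B (d_out x r), and their product is scaled by lora_alpha / r. The sketch below counts the extra trainable parameters for the q_proj and v_proj targets. The layer dimensions are assumptions about LLaMA 3.2 1B and should be checked against the model's config, and both helper names are hypothetical.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added to one linear layer: A (r x d_in) + B (d_out x r)."""
    return r * (d_in + d_out)


def lora_scaling(lora_alpha: int, r: int) -> float:
    """Factor applied to the low-rank update B @ A."""
    return lora_alpha / r


# Assumed dimensions: hidden size, key/value projection width, decoder layers.
HIDDEN, KV_DIM, NUM_LAYERS = 2048, 512, 16

per_layer = (
    lora_param_count(HIDDEN, HIDDEN, r=8)    # q_proj
    + lora_param_count(HIDDEN, KV_DIM, r=8)  # v_proj
)
total = per_layer * NUM_LAYERS
print(total, lora_scaling(lora_alpha=16, r=8))
```

The number printed by model.print_trainable_parameters() in build_model is the authoritative figure for a given checkpoint.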

build_pytorch_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_workers: int, mp_context: str = 'forkserver', device: torch.device | None = None) → _TDataLoader

Build a reusable PyTorch DataLoader for distributed LLM fine-tuning.

Build once before the training loop and reuse across epochs. The returned wrapper automatically calls DistributedSampler.set_epoch on each iteration, so no manual set_epoch call is needed. Similarly, yielded batches are automatically transferred to the given device, if any, so no manual to(device) call is needed.
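
The wrapper behaviour described above can be pictured as follows. This is a minimal sketch, not the example's actual implementation: EpochAwareLoader and RecordingSampler are hypothetical names, and the move callable stands in for the device transfer.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable, Iterator
from typing import Any


class EpochAwareLoader:
    """Wrap a loader so every new iteration reseeds the sampler and moves batches."""

    def __init__(
        self,
        loader: Iterable[Any],
        sampler: Any,
        move: Callable[[Any], Any] | None = None,
    ) -> None:
        self._loader = loader
        self._sampler = sampler
        self._move = move
        self._epoch = 0

    def __iter__(self) -> Iterator[Any]:
        # A new epoch number changes the shuffle order of a distributed sampler.
        self._sampler.set_epoch(self._epoch)
        self._epoch += 1
        for batch in self._loader:
            yield self._move(batch) if self._move is not None else batch


class RecordingSampler:
    """Stand-in sampler that only records the epochs it was given."""

    def __init__(self) -> None:
        self.epochs: list[int] = []

    def set_epoch(self, epoch: int) -> None:
        self.epochs.append(epoch)
```

With a real torch DataLoader and DistributedSampler, move would typically be a function that calls to(device) on each tensor in the batch.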

build_spdl_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_threads: int, mp_context: str = 'forkserver') → _TDataLoader

Build a reusable SPDL data loader with nested pipeline architecture.

Creates two nested pipelines to separate CPU-bound data loading from GPU transfer, reducing the noisy-neighbour effect where data loading threads in the main process compete with the training loop for CPU time, delaying GPU kernel launches.

Inner pipeline (runs in a subprocess):

Sampling → lookup → tokenize (concurrent) → aggregate → collate. All CPU work runs in a dedicated subprocess with its own thread pool, isolating it from the training process. The subprocess is created once and reused across epochs: each for ... in loop over the data loader rebuilds the pipeline inside the same subprocess.

Outer pipeline (runs in the main process):

Receives CPU batches from the subprocess via IPC queue and transfers them to GPU using transfer_tensor with a dedicated single-thread executor. This ensures GPU transfer uses a consistent CUDA stream and overlaps with training computation.
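
The idea of overlapping transfer with computation through a dedicated single-thread executor can be sketched with the standard library. This is not SPDL's implementation. prefetch_to_device is a hypothetical name, and transfer stands in for whatever moves a batch to the GPU.

```python
from concurrent.futures import ThreadPoolExecutor


def prefetch_to_device(batches, transfer):
    """Yield transferred batches, moving batch k+1 while the caller uses batch k."""
    # One worker thread means every transfer is issued from the same thread.
    with ThreadPoolExecutor(max_workers=1) as executor:
        it = iter(batches)
        try:
            pending = executor.submit(transfer, next(it))
        except StopIteration:
            return  # no batches at all
        for upcoming in it:
            ready = pending.result()
            # Start the next transfer before handing the ready batch to the caller.
            pending = executor.submit(transfer, upcoming)
            yield ready
        yield pending.result()
```

While the training loop processes one yielded batch, the executor thread is already transferring the next, and batch order is preserved.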

Build once before the training loop and iterate each epoch:

dataloader = build_spdl_dataloader(samples, tokenizer, ...)
for epoch in range(num_epochs):
    for batch in dataloader:
        train(batch)

load_data(paths: Sequence[str]) → list[dict[str, str]]

Load and concatenate data from one or more JSONL files.
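
A loader with this behaviour could look like the sketch below. It is an illustration under assumptions, not the example's code: load_jsonl is a hypothetical name, and only local files are handled.

```python
import json
from collections.abc import Sequence
from pathlib import Path


def load_jsonl(paths: Sequence[str]) -> list[dict[str, str]]:
    """Read every non-empty line of each file as one JSON record, in file order."""
    samples: list[dict[str, str]] = []
    for path in paths:
        with Path(path).open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines, e.g. a trailing newline
                    samples.append(json.loads(line))
    return samples
```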

main(args: Namespace) → None

train(*, model_path: str, data_path: list[str], output_dir: str, max_seq_len: int, batch_size: int, num_epochs: int, lr: float, weight_decay: float, max_grad_norm: float, log_interval: int, lora_r: int, lora_alpha: int, lora_dropout: float, num_workers: int, dataloader: str = 'spdl', mp_context: str = 'forkserver', progress_fn: Callable[[int, int], None] | None = None) → None

Main training function, called per-rank.
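
The learning rate in train follows a cosine schedule that decays from lr to 0.1 * lr over the total number of steps. The closed form of cosine annealing can be sketched as below. cosine_lr is a hypothetical helper, and the assumption is that nothing other than the scheduler modifies the learning rate.

```python
import math


def cosine_lr(
    step: int, total_steps: int, base_lr: float, eta_min_ratio: float = 0.1
) -> float:
    """Cosine annealing from base_lr at step 0 to base_lr * eta_min_ratio at the end."""
    eta_min = base_lr * eta_min_ratio
    cosine = math.cos(math.pi * step / total_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


# With the default --lr of 5e-4 the rate ends at 5e-5.
for step in (0, 50, 100):
    print(step, cosine_lr(step, total_steps=100, base_lr=5e-4))
```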