LLM Fine-Tuning¶
LLM fine-tuning with SPDL data loading pipeline.
Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA, with SPDL PipelineBuilder for high-performance concurrent tokenization.
SPDL Data Pipeline¶
The core of this example is the SPDL data loading pipeline:
1. DistributedRandomSampler – distributes sample indices across ranks with per-epoch reshuffling
2. pipe(tokenize, concurrency=N) – concurrent Alpaca-format prompt formatting and tokenization
3. aggregate(batch_size) – groups samples into batches
4. pipe(collate) – stacks tensors
5. add_sink(buffer_size=3) – prefetch buffer for the training loop
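For orientation, these stages compose through spdl.pipeline.PipelineBuilder roughly as follows. This is a minimal sketch, not the example's actual code: sampler, tokenize, and collate stand in for the helpers defined in the example's utils, and the concurrency and batch values mirror the script defaults.

    from spdl.pipeline import PipelineBuilder

    # sampler: a DistributedRandomSampler yielding per-rank sample indices (construction omitted)
    # tokenize: formats an Alpaca-style prompt and tokenizes it (placeholder)
    # collate: pads and stacks tokenized samples into tensors (placeholder)

    pipeline = (
        PipelineBuilder()
        .add_source(sampler)            # per-rank sample indices, reshuffled each epoch
        .pipe(tokenize, concurrency=8)  # concurrent prompt formatting + tokenization
        .aggregate(32)                  # group 32 samples into a batch
        .pipe(collate)                  # stack into padded tensors
        .add_sink(buffer_size=3)        # prefetch buffer for the training loop
        .build(num_threads=8)
    )

    with pipeline.auto_stop():
        for batch in pipeline:
            ...  # training step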
Data¶
Download instruction-following datasets:
# https://github.com/tatsu-lab/stanford_alpaca
python download_alpaca.py --output /tmp/alpaca.jsonl
# https://huggingface.co/datasets/databricks/databricks-dolly-15k
python download_dolly.py --output /tmp/dolly.jsonl
Data format (JSONL with Alpaca-style fields):
{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}
Usage¶
torchrun \
--nproc_per_node 8 \
-m spdl.examples.llm_finetune.llm_finetuning \
--model-path /path/to/Llama-3.2-1B-Instruct \
--data-path \
/tmp/alpaca.jsonl \
/tmp/dolly.jsonl
With the default settings (global batch size 8x32, i.e. 256 samples per optimizer step), training throughput reaches roughly 570 samples/sec on H100 GPUs, or about 2.2 optimizer steps per second.
Source¶
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA,
with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline
^^^^^^^^^^^^^^^^^^

The core of this example is the SPDL data loading pipeline:

1. ``DistributedRandomSampler`` -- distributes sample indices across ranks
   with per-epoch reshuffling
2. ``pipe(tokenize, concurrency=N)`` -- concurrent Alpaca-format prompt
   formatting and tokenization
3. ``aggregate(batch_size)`` -- groups samples into batches
4. ``pipe(collate)`` -- stacks tensors
5. ``add_sink(buffer_size=3)`` -- prefetch buffer for the training loop

Data
^^^^

Download instruction-following datasets::

    # https://github.com/tatsu-lab/stanford_alpaca
    python download_alpaca.py --output /tmp/alpaca.jsonl
    # https://huggingface.co/datasets/databricks/databricks-dolly-15k
    python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields)::

    {"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}

Usage
^^^^^

::

    torchrun \\
        --nproc_per_node 8 \\
        -m spdl.examples.llm_finetune.llm_finetuning \\
        --model-path /path/to/Llama-3.2-1B-Instruct \\
        --data-path \\
            /tmp/alpaca.jsonl \\
            /tmp/dolly.jsonl

With the default settings (global batch size 8x32), the training throughput
reaches roughly 570 samples/sec on H100 GPUs.
"""

from __future__ import annotations

__all__ = [
    "build_model",
    "build_pytorch_dataloader",
    "build_spdl_dataloader",
    "load_data",
    "main",
    "train",
]

# pyre-strict

import argparse
import logging
import os
import time
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

try:
    from examples.llm_finetune.utils.dataloader import (  # pyre-ignore[21]
        build_pytorch_dataloader,
    )
    from examples.llm_finetune.utils.pipeline import (  # pyre-ignore[21]
        build_spdl_dataloader,
    )
    from examples.llm_finetune.utils.utils import (  # pyre-ignore[21]
        load_data,
        report_progress,
        resolve_model_path,
    )
except ImportError:
    from spdl.examples.llm_finetune.utils.dataloader import build_pytorch_dataloader
    from spdl.examples.llm_finetune.utils.pipeline import (
        build_spdl_dataloader,
    )
    from spdl.examples.llm_finetune.utils.utils import (
        load_data,
        report_progress,
        resolve_model_path,
    )

_LG: logging.Logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------


def build_model(
    model_path: str,
    device: torch.device,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
) -> torch.nn.Module:
    """Load LLaMA model and apply LoRA."""
    from peft import get_peft_model, LoraConfig, TaskType
    from transformers import AutoModelForCausalLM

    _LG.info("Loading model from %s", model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = model.to(device)
    return model


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------


def train(
    *,
    model_path: str,
    data_path: list[str],
    output_dir: str,
    max_seq_len: int,
    batch_size: int,
    num_epochs: int,
    lr: float,
    weight_decay: float,
    max_grad_norm: float,
    log_interval: int,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    num_workers: int,
    dataloader: str = "spdl",
    mp_context: str = "forkserver",
    progress_fn: Callable[[int, int], None] | None = None,
) -> None:
    """Main training function, called per-rank."""
    rank: int = dist.get_rank()
    world_size: int = dist.get_world_size()
    local_rank: int = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    _LG.info(
        "Rank %d/%d on device %s (dataloader=%s)",
        rank,
        world_size,
        device,
        dataloader,
    )

    # --- Data ---
    samples = load_data(data_path)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model = build_model(
        model_path,
        device,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    ddp_model = DDP(model, device_ids=[local_rank])

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        ddp_model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        foreach=True,
    )

    num_steps_per_epoch = len(samples) // (batch_size * world_size)
    total_steps = num_steps_per_epoch * num_epochs
    if rank == 0:
        _LG.info(
            "Training: %d samples, %d epochs, %d steps/epoch, %d total steps",
            len(samples),
            num_epochs,
            num_steps_per_epoch,
            total_steps,
        )
        if progress_fn is not None:
            progress_fn(0, total_steps)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=lr * 0.1,
    )

    # --- Build data source ---
    if dataloader == "pytorch":
        dl = build_pytorch_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_workers=num_workers,
            mp_context=mp_context,
            device=device,
        )
    else:
        dl = build_spdl_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_threads=num_workers,
            mp_context=mp_context,
        )

    # --- Training loop ---
    global_step = 0
    ddp_model.train()

    for epoch in range(num_epochs):
        _LG.info("Epoch %d/%d", epoch + 1, num_epochs)

        t0 = time.monotonic()
        epoch_loss = 0.0
        num_batches = 0

        for batch in dl:
            outputs = ddp_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            num_batches += 1
            global_step += 1

            if rank == 0:
                if progress_fn is not None:
                    progress_fn(global_step, total_steps)
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / num_batches
                    elapsed = time.monotonic() - t0
                    _LG.info(
                        "Step %d | loss=%.4f | lr=%.2e | %.1f samples/s",
                        global_step,
                        avg_loss,
                        scheduler.get_last_lr()[0],
                        num_batches * batch_size * world_size / elapsed,
                    )

        elapsed = time.monotonic() - t0
        if rank == 0:
            avg_loss = epoch_loss / max(num_batches, 1)
            _LG.info(
                "Epoch %d complete | avg_loss=%.4f | %.1fs | %.1f samples/s",
                epoch + 1,
                avg_loss,
                elapsed,
                num_batches * batch_size * world_size / elapsed,
            )

    # --- Save ---
    if rank == 0 and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        ddp_model.module.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        _LG.info("Model saved to %s", output_path)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    # Model
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to pretrained LLaMA model directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Directory to save fine-tuned LoRA weights",
    )
    # Data
    parser.add_argument(
        "--data-path",
        type=str,
        nargs="+",
        required=True,
        help="One or more paths to Alpaca-format JSONL files (local or manifold://).",
    )
    parser.add_argument("--max-seq-len", type=int, default=512)
    # Training
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--log-interval", type=int, default=10)
    # LoRA
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Pipeline
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Concurrent tokenization workers in the data pipeline",
    )
    parser.add_argument(
        "--dataloader",
        type=str,
        choices=["spdl", "pytorch"],
        default="spdl",
        help="Data loading backend: 'spdl' (default) or 'pytorch' (torch DataLoader)",
    )
    parser.add_argument(
        "--mp-context",
        type=str,
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
        help="Multiprocessing context for workers (default: forkserver)",
    )
    return parser.parse_args()


def init_logging() -> None:
    """Initialize logging."""
    rank = os.environ.get("RANK", "?")
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s [%(levelname).1s] [Rank{rank}] %(name)s: %(message)s",
    )


def main(args: argparse.Namespace) -> None:
    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=3))
    try:
        train(
            model_path=resolve_model_path(args.model_path),
            data_path=args.data_path,
            output_dir=args.output_dir,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            max_grad_norm=args.max_grad_norm,
            log_interval=args.log_interval,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            num_workers=args.num_workers,
            dataloader=args.dataloader,
            mp_context=args.mp_context,
            progress_fn=report_progress,
        )
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    init_logging()
    main(parse_args())
API Reference¶
Functions
- build_model(model_path: str, device: torch.device, lora_r: int, lora_alpha: int, lora_dropout: float) → torch.nn.Module¶
Load LLaMA model and apply LoRA.
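A minimal call, using the script's default LoRA hyperparameters (--lora-r 8, --lora-alpha 16, --lora-dropout 0.05):

    import torch

    device = torch.device("cuda:0")
    model = build_model(
        "/path/to/Llama-3.2-1B-Instruct",  # local model directory
        device,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.05,
    )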
- build_pytorch_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_workers: int, mp_context: str = 'forkserver', device: torch.device | None = None) → _TDataLoader¶
Build a reusable PyTorch DataLoader for distributed LLM fine-tuning.
Build once before the training loop and reuse across epochs. The returned wrapper automatically calls DistributedSampler.set_epoch on each iteration, so no manual set_epoch call is needed. Similarly, yielded batches are automatically transferred to the given device, if any, so no manual to(device) call is needed.
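A usage sketch mirroring the train() function in the source above; samples, tokenizer, device, and num_epochs are assumed to be set up as shown there, and the numeric values are the script defaults:

    import torch.distributed as dist

    dl = build_pytorch_dataloader(
        samples=samples,          # loaded via load_data(...)
        tokenizer=tokenizer,      # AutoTokenizer with pad_token set
        max_seq_len=512,
        batch_size=4,
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_workers=8,
        mp_context="forkserver",
        device=device,            # yielded batches land here automatically
    )

    for epoch in range(num_epochs):
        for batch in dl:          # set_epoch is handled by the wrapper
            ...                   # forward/backward step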
- build_spdl_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_threads: int, mp_context: str = 'forkserver') → _TDataLoader¶
Build a reusable SPDL data loader with nested pipeline architecture.
Creates two nested pipelines to separate CPU-bound data loading from GPU transfer, reducing the noisy-neighbour effect where data loading threads in the main process compete with the training loop for CPU time, delaying GPU kernel launches.
- Inner pipeline (runs in a subprocess):
  Sampling → lookup → tokenize (concurrent) → aggregate → collate. All CPU work runs in a dedicated subprocess with its own thread pool, completely isolating it from the training process. The subprocess is created once and reused across epochs; each for ... in call rebuilds the pipeline inside the same subprocess.
- Outer pipeline (runs in the main process):
  Receives CPU batches from the subprocess via IPC queue and transfers them to GPU using transfer_tensor with a dedicated single-thread executor. This ensures GPU transfer uses a consistent CUDA stream and overlaps with training computation.
Build once before the training loop and iterate each epoch:
dataloader = build_spdl_dataloader(samples, tokenizer, ...)
for epoch in range(num_epochs):
    for batch in dataloader:
        train(batch)
- load_data(paths: Sequence[str]) → list[dict[str, str]]¶
Load and concatenate data from one or more JSONL files.
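Functionally this amounts to the sketch below, assuming plain local files; the real helper also accepts the manifold:// paths mentioned in --data-path's help text. The name load_data_sketch marks it as illustrative, not the actual implementation.

    import json
    from collections.abc import Sequence

    def load_data_sketch(paths: Sequence[str]) -> list[dict[str, str]]:
        # Concatenate records from one or more JSONL files, preserving order.
        samples: list[dict[str, str]] = []
        for path in paths:
            with open(path) as f:
                for line in f:
                    if line.strip():
                        samples.append(json.loads(line))
        return samples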
- train(*, model_path: str, data_path: list[str], output_dir: str, max_seq_len: int, batch_size: int, num_epochs: int, lr: float, weight_decay: float, max_grad_norm: float, log_interval: int, lora_r: int, lora_alpha: int, lora_dropout: float, num_workers: int, dataloader: str = 'spdl', mp_context: str = 'forkserver', progress_fn: Callable[[int, int], None] | None = None) → None¶
Main training function, called per-rank.