LLM fine-tuning
LLM fine-tuning with SPDL data loading pipeline.
Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA, with SPDL PipelineBuilder for high-performance concurrent tokenization.
SPDL Data Pipeline
The core of this example is the SPDL data loading pipeline:
1. DistributedRandomSampler – distributes sample indices across ranks with per-epoch reshuffling
2. pipe(tokenize, concurrency=N) – concurrent Alpaca-format prompt formatting and tokenization
3. aggregate(batch_size) – groups samples into batches
4. pipe(collate) – stacks tensors
5. add_sink(buffer_size=3) – prefetch buffer for the training loop
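To show how stages 2 to 5 fit together, here is a plain-Python analogue built on the standard library. It is a sketch, not SPDL code: `concurrent_map`, `aggregate`, `collate`, `prefetch`, `toy_tokenize` and `build_loader` are hypothetical names, the tokenizer is a character-level stand-in, stage 1 is replaced by iterating the samples in order, and results are kept in input order (SPDL's own ordering options may differ).

```python
import queue
import threading
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import TypeVar

T = TypeVar("T")
U = TypeVar("U")


def concurrent_map(fn: Callable[[T], U], items: Iterable[T], concurrency: int) -> Iterator[U]:
    """Run ``fn`` on up to ``concurrency`` items at once, yielding in input order."""
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        pending: deque = deque()
        for item in items:
            pending.append(pool.submit(fn, item))
            if len(pending) >= concurrency:
                yield pending.popleft().result()
        while pending:
            yield pending.popleft().result()


def aggregate(items: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    """Group items into lists of ``batch_size``; the last batch may be shorter."""
    it = iter(items)
    while batch := list(islice(it, batch_size)):
        yield batch


def toy_tokenize(sample: dict[str, str]) -> list[int]:
    # Hypothetical stand-in for a real tokenizer: one "token" per character.
    return [ord(c) for c in sample["instruction"]]


def collate(batch: list[list[int]], pad_id: int = 0) -> dict[str, list[list[int]]]:
    """Right-pad token lists to a common length and build an attention mask."""
    width = max(len(ids) for ids in batch)
    return {
        "input_ids": [ids + [pad_id] * (width - len(ids)) for ids in batch],
        "attention_mask": [[1] * len(ids) + [0] * (width - len(ids)) for ids in batch],
    }


def prefetch(items: Iterable[T], buffer_size: int) -> Iterator[T]:
    """Produce items on a background thread, holding at most ``buffer_size`` ready."""
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)
    done = object()
    errors: list[BaseException] = []

    def producer() -> None:
        try:
            for item in items:
                buf.put(item)
        except BaseException as exc:  # hand upstream failures to the consumer
            errors.append(exc)
        finally:
            buf.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while (item := buf.get()) is not done:
        yield item
    if errors:
        raise errors[0]


def build_loader(samples, batch_size: int, concurrency: int, buffer_size: int = 3):
    tokenized = concurrent_map(toy_tokenize, samples, concurrency)  # stage 2
    batches = (collate(b) for b in aggregate(tokenized, batch_size))  # stages 3-4
    return prefetch(batches, buffer_size)  # stage 5


samples = [{"instruction": "ab"}, {"instruction": "c"}, {"instruction": "def"}]
batches = list(build_loader(samples, batch_size=2, concurrency=2))
```

The bounded queue in `prefetch` is what makes the sink a prefetch buffer: the producer can run ahead of the training loop by at most `buffer_size` batches and then blocks. In this sketch, abandoning the iterator early leaves the daemon producer thread blocked; a production pipeline also needs an explicit shutdown path.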
Data
Download instruction-following datasets:
# https://github.com/tatsu-lab/stanford_alpaca
python download_alpaca.py --output /tmp/alpaca.jsonl
# https://huggingface.co/datasets/databricks/databricks-dolly-15k
python download_dolly.py --output /tmp/dolly.jsonl
Data format (JSONL with Alpaca-style fields):
{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}
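A record in this format can be turned into a prompt/target pair and then into model inputs. The sketch below assumes the two prompt templates published in the tatsu-lab/stanford_alpaca repository and the common convention of labelling prompt positions with -100, which Hugging Face causal-LM losses ignore. The example's own tokenize step is not shown on this page and may use a different template or masking scheme; `format_alpaca` and `build_example` are hypothetical names.

```python
import json

IGNORE_INDEX = -100  # positions with this label do not contribute to the loss

# Assumed templates (stanford_alpaca style); the example's own may differ.
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)


def format_alpaca(record: dict[str, str]) -> tuple[str, str]:
    """Return (prompt, target) for one Alpaca-style record."""
    # An empty "input" field selects the shorter template.
    template = PROMPT_WITH_INPUT if record.get("input") else PROMPT_NO_INPUT
    prompt = template.format(instruction=record["instruction"], input=record.get("input", ""))
    return prompt, record["output"]


def build_example(
    prompt_ids: list[int], target_ids: list[int], max_seq_len: int
) -> tuple[list[int], list[int]]:
    """Concatenate prompt and target, mask the prompt in labels, then truncate."""
    input_ids = (prompt_ids + target_ids)[:max_seq_len]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + target_ids)[:max_seq_len]
    return input_ids, labels


record = json.loads(
    '{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}'
)
prompt, target = format_alpaca(record)
ids, labels = build_example([1, 2, 3], [4, 5], max_seq_len=4)
```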
Usage
torchrun \
--nproc_per_node 8 \
-m spdl.examples.llm_finetune.llm_finetuning \
--model-path /path/to/Llama-3.2-1B-Instruct \
--data-path \
/tmp/alpaca.jsonl \
/tmp/dolly.jsonl
With the default settings (global batch size 8x32), the training throughput reaches roughly 570 samples/s on H100 GPUs.
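Reading "8x32" as 8 ranks times 32 samples per rank, the quoted throughput (about 570 samples per second) can be converted into optimizer steps with the same step-count formula that train() uses in the source listing. The figures below take the quoted numbers at face value and use a hypothetical 10,000-sample dataset; the helper names are hypothetical. Note that parse_args() defaults --batch-size to 4, so the 8x32 configuration presumably requires passing --batch-size 32 explicitly.

```python
def global_batch_size(per_rank_batch: int, world_size: int) -> int:
    return per_rank_batch * world_size


def steps_per_epoch(num_samples: int, per_rank_batch: int, world_size: int) -> int:
    # Same integer division as train(): a trailing partial global batch is not counted.
    return num_samples // global_batch_size(per_rank_batch, world_size)


def optimizer_steps_per_second(samples_per_sec: float, per_rank_batch: int, world_size: int) -> float:
    return samples_per_sec / global_batch_size(per_rank_batch, world_size)


gbs = global_batch_size(32, 8)                 # 8 ranks x 32 samples per rank
rate = optimizer_steps_per_second(570, 32, 8)  # about 2.2 optimizer steps per second
steps = steps_per_epoch(10_000, 32, 8)         # hypothetical 10,000-sample dataset
```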
Source
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA,
with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline
^^^^^^^^^^^^^^^^^^

The core of this example is the SPDL data loading pipeline:

1. ``DistributedRandomSampler`` -- distributes sample indices across ranks
   with per-epoch reshuffling
2. ``pipe(tokenize, concurrency=N)`` -- concurrent Alpaca-format prompt
   formatting and tokenization
3. ``aggregate(batch_size)`` -- groups into batches
4. ``pipe(collate)`` -- stacks tensors
5. ``add_sink(buffer_size=3)`` -- prefetch buffer for the training loop

Data
^^^^

Download instruction-following datasets::

    # https://github.com/tatsu-lab/stanford_alpaca
    python download_alpaca.py --output /tmp/alpaca.jsonl
    # https://huggingface.co/datasets/databricks/databricks-dolly-15k
    python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields)::

    {"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}

Usage
^^^^^

::

    torchrun \\
        --nproc_per_node 8 \\
        -m spdl.examples.llm_finetune.llm_finetuning \\
        --model-path /path/to/Llama-3.2-1B-Instruct \\
        --data-path \\
            /tmp/alpaca.jsonl \\
            /tmp/dolly.jsonl

With the default settings (global batch size 8x32), the training throughput reaches roughly 570
samples/s on H100 GPUs.
"""

from __future__ import annotations

__all__ = [
    "build_model",
    "build_pytorch_dataloader",
    "build_spdl_dataloader",
    "load_data",
    "main",
    "train",
]

# pyre-strict

import argparse
import logging
import os
import time
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

try:
    from examples.llm_finetune.utils.dataloader import (  # pyre-ignore[21]
        build_pytorch_dataloader,
    )
    from examples.llm_finetune.utils.pipeline import (  # pyre-ignore[21]
        build_spdl_dataloader,
    )
    from examples.llm_finetune.utils.utils import (  # pyre-ignore[21]
        load_data,
        report_progress,
        resolve_model_path,
    )
except ImportError:
    from spdl.examples.llm_finetune.utils.dataloader import build_pytorch_dataloader
    from spdl.examples.llm_finetune.utils.pipeline import (
        build_spdl_dataloader,
    )
    from spdl.examples.llm_finetune.utils.utils import (
        load_data,
        report_progress,
        resolve_model_path,
    )

_LG: logging.Logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------


def build_model(
    model_path: str,
    device: torch.device,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
) -> torch.nn.Module:
    """Load LLaMA model and apply LoRA."""
    from peft import get_peft_model, LoraConfig, TaskType
    from transformers import AutoModelForCausalLM

    _LG.info("Loading model from %s", model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = model.to(device)
    return model


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------


def train(
    *,
    model_path: str,
    data_path: list[str],
    output_dir: str,
    max_seq_len: int,
    batch_size: int,
    num_epochs: int,
    lr: float,
    weight_decay: float,
    max_grad_norm: float,
    log_interval: int,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    num_workers: int,
    dataloader: str = "spdl",
    mp_context: str = "forkserver",
    progress_fn: Callable[[int, int], None] | None = None,
) -> None:
    """Main training function, called per-rank."""
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    local_rank: int = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    _LG.info(
        "Rank %d/%d on device %s (dataloader=%s)",
        rank,
        world_size,
        device,
        dataloader,
    )

    # --- Data ---
    samples = load_data(data_path)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model = build_model(
        model_path,
        device,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    if dist.is_initialized():
        ddp_model = DDP(
            model,
            device_ids=[local_rank] if torch.cuda.is_available() else None,
        )
    else:
        ddp_model = model

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        ddp_model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        foreach=True,
    )

    num_steps_per_epoch = len(samples) // (batch_size * world_size)
    total_steps = num_steps_per_epoch * num_epochs
    if rank == 0:
        _LG.info(
            "Training: %d samples, %d epochs, %d steps/epoch, %d total steps",
            len(samples),
            num_epochs,
            num_steps_per_epoch,
            total_steps,
        )
        if progress_fn is not None:
            progress_fn(0, total_steps)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=lr * 0.1,
    )

    # --- Build data source ---
    if dataloader == "pytorch":
        dl = build_pytorch_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_workers=num_workers,
            mp_context=mp_context,
            device=device,
        )
    else:
        dl = build_spdl_dataloader(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_threads=num_workers,
            mp_context=mp_context,
        )

    # --- Training loop ---
    global_step = 0
    ddp_model.train()

    for epoch in range(num_epochs):
        _LG.info("Epoch %d/%d", epoch + 1, num_epochs)

        t0 = time.monotonic()
        epoch_loss = 0.0
        num_batches = 0

        for batch in dl:
            outputs = ddp_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            num_batches += 1
            global_step += 1

            if rank == 0:
                if progress_fn is not None:
                    progress_fn(global_step, total_steps)
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / num_batches
                    elapsed = time.monotonic() - t0
                    _LG.info(
                        "Step %d | loss=%.4f | lr=%.2e | %.1f samples/s",
                        global_step,
                        avg_loss,
                        scheduler.get_last_lr()[0],
                        num_batches * batch_size * world_size / elapsed,
                    )

        elapsed = time.monotonic() - t0
        if rank == 0:
            avg_loss = epoch_loss / max(num_batches, 1)
            _LG.info(
                "Epoch %d complete | avg_loss=%.4f | %.1fs | %.1f samples/s",
                epoch + 1,
                avg_loss,
                elapsed,
                num_batches * batch_size * world_size / elapsed,
            )

    # --- Save ---
    if rank == 0 and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(output_path)  # pyre-ignore[29]
        tokenizer.save_pretrained(output_path)
        _LG.info("Model saved to %s", output_path)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    # Model
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to pretrained LLaMA model directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Directory to save fine-tuned LoRA weights",
    )
    # Data
    parser.add_argument(
        "--data-path",
        type=str,
        nargs="+",
        required=True,
        help="One or more paths to Alpaca-format JSONL files (local or manifold://).",
    )
    parser.add_argument("--max-seq-len", type=int, default=512)
    # Training
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--log-interval", type=int, default=10)
    # LoRA
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Pipeline
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Concurrent tokenization workers in the data pipeline",
    )
    parser.add_argument(
        "--dataloader",
        type=str,
        choices=["spdl", "pytorch"],
        default="spdl",
        help="Data loading backend: 'spdl' (default) or 'pytorch' (torch DataLoader)",
    )
    parser.add_argument(
        "--mp-context",
        type=str,
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
        help="Multiprocessing context for workers (default: forkserver)",
    )
    return parser.parse_args()


def init_logging() -> None:
    """Initialize logging."""
    rank = os.environ.get("RANK", "?")
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s [%(levelname).1s] [Rank{rank}] %(name)s: %(message)s",
    )


def main(args: argparse.Namespace) -> None:
    use_distributed = "RANK" in os.environ
    if use_distributed:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, timeout=timedelta(minutes=3))
    try:
        train(
            model_path=resolve_model_path(args.model_path),
            data_path=args.data_path,
            output_dir=args.output_dir,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            max_grad_norm=args.max_grad_norm,
            log_interval=args.log_interval,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            num_workers=args.num_workers,
            dataloader=args.dataloader,
            mp_context=args.mp_context,
            progress_fn=report_progress,
        )
    finally:
        if use_distributed:
            dist.destroy_process_group()


if __name__ == "__main__":
    init_logging()
    main(parse_args())
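train() pairs AdamW with CosineAnnealingLR(T_max=total_steps, eta_min=lr * 0.1). PyTorch documents the schedule's closed form as eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2, so the learning rate decays from lr to 0.1 * lr over the run. The standard-library sketch below evaluates that formula; cosine_lr is a hypothetical helper, and it assumes the scheduler is stepped once per optimizer step, as in the training loop of the listing.

```python
import math


def cosine_lr(step: int, base_lr: float, total_steps: int, min_ratio: float = 0.1) -> float:
    """Closed-form cosine annealing from base_lr down to base_lr * min_ratio."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    eta_min = base_lr * min_ratio
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / total_steps)) / 2


# Default --lr 5e-4 over a hypothetical 1,000-step run.
start = cosine_lr(0, 5e-4, 1_000)
middle = cosine_lr(500, 5e-4, 1_000)
end = cosine_lr(1_000, 5e-4, 1_000)
```

Because total_steps is computed as len(samples) // (batch_size * world_size) per epoch, a loader that yields extra batches per epoch (for example a final partial batch) would push t past T_max, where the closed form starts climbing back toward lr. Whether that happens depends on how each loader batches the tail of the data.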
API Reference
Functions
- build_model(model_path: str, device: torch.device, lora_r: int, lora_alpha: int, lora_dropout: float) → torch.nn.Module
Load LLaMA model and apply LoRA.
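build_model() adapts only q_proj and v_proj. For a bias-free linear layer with input width d_in and output width d_out, LoRA adds an r x d_in matrix A and a d_out x r matrix B, i.e. r * (d_in + d_out) trainable parameters, and PEFT scales the update by lora_alpha / r. The arithmetic below assumes Llama-3.2-1B shapes (hidden size 2048, 16 layers, 8 key/value heads of width 64); these values are assumptions to check against the checkpoint's config.json, and lora_param_count is a hypothetical helper.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one bias-free linear layer."""
    # A has shape (r, d_in) and B has shape (d_out, r).
    return r * (d_in + d_out)


# Assumed Llama-3.2-1B shapes; verify against the checkpoint's config.json.
HIDDEN_SIZE = 2048          # hidden_size: q_proj input and output width
NUM_LAYERS = 16             # num_hidden_layers
V_PROJ_OUT = 8 * 64         # num_key_value_heads * head_dim (grouped-query attention)
LORA_R, LORA_ALPHA = 8, 16  # defaults of --lora-r and --lora-alpha

per_layer = lora_param_count(HIDDEN_SIZE, HIDDEN_SIZE, LORA_R) + lora_param_count(
    HIDDEN_SIZE, V_PROJ_OUT, LORA_R
)
total = NUM_LAYERS * per_layer
scaling = LORA_ALPHA / LORA_R  # factor applied to the low-rank update B @ A
```

The call to print_trainable_parameters() in build_model reports the real count, which is a direct check of these assumed shapes.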
- build_pytorch_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_workers: int, mp_context: str = 'forkserver', device: torch.device | None = None) → _TDataLoader
Build a reusable PyTorch DataLoader for distributed LLM fine-tuning.
Build once before the training loop and reuse across epochs. The returned wrapper automatically calls DistributedSampler.set_epoch on each iteration, so no manual set_epoch call is needed. Similarly, yielded batches are automatically transferred to the given device, if any, so no manual to(device) call is needed.
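The automatic set_epoch behaviour described above can be obtained with a thin wrapper around any re-iterable loader. The class below is a hypothetical sketch, not the wrapper that build_pytorch_dataloader returns: it calls set_epoch each time iteration starts and applies an optional per-batch transfer function that stands in for the device copy.

```python
from collections.abc import Callable, Iterable, Iterator
from typing import Any, Protocol


class SupportsSetEpoch(Protocol):
    def set_epoch(self, epoch: int) -> None: ...


class EpochAwareLoader:
    """Re-iterable loader that advances the sampler epoch on every pass."""

    def __init__(
        self,
        loader: Iterable[Any],
        sampler: SupportsSetEpoch,
        transfer: Callable[[Any], Any] = lambda batch: batch,
    ) -> None:
        self._loader = loader
        self._sampler = sampler
        self._transfer = transfer
        self._epoch = 0

    def __iter__(self) -> Iterator[Any]:
        # Set the epoch before the underlying loader draws its indices.
        self._sampler.set_epoch(self._epoch)
        self._epoch += 1
        return (self._transfer(batch) for batch in self._loader)


# Usage with a stand-in sampler that records the epochs it receives.
class RecordingSampler:
    def __init__(self) -> None:
        self.epochs: list[int] = []

    def set_epoch(self, epoch: int) -> None:
        self.epochs.append(epoch)


sampler = RecordingSampler()
loader = EpochAwareLoader([1, 2, 3], sampler, transfer=lambda batch: batch * 10)
first, second = list(loader), list(loader)
```

Putting the epoch bookkeeping inside __iter__ is what lets the caller write a plain `for batch in loader` per epoch without forgetting the reshuffle.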
- build_spdl_dataloader(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_threads: int, mp_context: str = 'forkserver') → _TDataLoader
Build a reusable SPDL data loader with a nested pipeline architecture.
Creates two nested pipelines to separate CPU-bound data loading from GPU transfer. This reduces the noisy-neighbour effect, in which data loading threads in the main process compete with the training loop for CPU time and delay GPU kernel launches.
  - Inner pipeline (runs in a subprocess): sampling → lookup → tokenize (concurrent) → aggregate → collate. All CPU work runs in a dedicated subprocess with its own thread pool, isolating it from the training process. The subprocess is created once and reused across epochs; each for ... in loop rebuilds the pipeline inside the same subprocess.
  - Outer pipeline (runs in the main process): receives CPU batches from the subprocess via an IPC queue and transfers them to the GPU using transfer_tensor with a dedicated single-thread executor. This keeps GPU transfers on a consistent CUDA stream so they overlap with training computation.
Build once before the training loop and iterate each epoch:
    dataloader = build_spdl_dataloader(samples, tokenizer, ...)
    for epoch in range(num_epochs):
        for batch in dataloader:
            train(batch)
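The sampling stage at the head of the inner pipeline (DistributedRandomSampler in the module docstring) has to give each rank a disjoint share of indices that changes every epoch while staying identical across ranks. A minimal standard-library sketch, assuming a seed shared by all ranks and the seed + epoch seeding scheme that torch.utils.data.DistributedSampler also uses; the function name and the tail-dropping policy are assumptions, not SPDL's implementation.

```python
import random


def distributed_random_indices(
    num_samples: int, rank: int, world_size: int, epoch: int, seed: int = 0
) -> list[int]:
    """Return this rank's shuffled share of range(num_samples) for one epoch."""
    # All ranks build the same permutation because they share (seed, epoch).
    rng = random.Random(seed + epoch)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    # Drop the remainder so every rank gets the same number of samples;
    # an uneven split would leave some ranks waiting in the gradient all-reduce.
    usable = num_samples - num_samples % world_size
    # Strided slicing assigns disjoint positions of the permutation to each rank.
    return indices[rank:usable:world_size]


shards = [distributed_random_indices(10, rank, 4, epoch=0) for rank in range(4)]
```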
- load_data(paths: Sequence[str]) → list[dict[str, str]]
Load and concatenate data from one or more JSONL files.
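The contract of load_data (read each JSONL file and concatenate the records in argument order) can be sketched with the standard library. load_jsonl_files is a hypothetical name; unlike the real function, which according to the --data-path help text also accepts manifold:// paths, this sketch reads local files only and skips blank lines.

```python
import json
import tempfile
from collections.abc import Sequence
from pathlib import Path


def load_jsonl_files(paths: Sequence[str]) -> list[dict[str, str]]:
    """Read every JSON object from each file, preserving file and line order."""
    samples: list[dict[str, str]] = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines, e.g. a trailing newline
                    samples.append(json.loads(line))
    return samples


# Usage: two small files standing in for alpaca.jsonl and dolly.jsonl.
with tempfile.TemporaryDirectory() as tmp:
    path_a = Path(tmp) / "a.jsonl"
    path_b = Path(tmp) / "b.jsonl"
    path_a.write_text('{"instruction": "x", "input": "", "output": "y"}\n\n', encoding="utf-8")
    path_b.write_text('{"instruction": "z", "input": "", "output": "w"}\n', encoding="utf-8")
    merged = load_jsonl_files([str(path_a), str(path_b)])
```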
- train(*, model_path: str, data_path: list[str], output_dir: str, max_seq_len: int, batch_size: int, num_epochs: int, lr: float, weight_decay: float, max_grad_norm: float, log_interval: int, lora_r: int, lora_alpha: int, lora_dropout: float, num_workers: int, dataloader: str = 'spdl', mp_context: str = 'forkserver', progress_fn: Callable[[int, int], None] | None = None) → None
Main training function, called per-rank.
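train() reads the global rank and world size from torch.distributed but takes the local rank, and hence the CUDA device, straight from the LOCAL_RANK environment variable; main() enables distributed mode only when RANK is set. Since torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker, the same decisions can be sketched as a pure function of the environment. resolve_ranks and device_name are hypothetical helpers, and the sketch reads WORLD_SIZE instead of querying an initialized process group.

```python
from collections.abc import Mapping


def resolve_ranks(env: Mapping[str, str]) -> tuple[int, int, int]:
    """Return (rank, world_size, local_rank) from torchrun-style variables."""
    local_rank = int(env.get("LOCAL_RANK", "0"))
    if "RANK" not in env:  # plain single-process launch, as in main()
        return 0, 1, local_rank
    return int(env["RANK"]), int(env["WORLD_SIZE"]), local_rank


def device_name(local_rank: int, cuda_available: bool) -> str:
    # One GPU per process, indexed by the process's rank within its node.
    return f"cuda:{local_rank}" if cuda_available else "cpu"


single = resolve_ranks({})
worker = resolve_ranks({"RANK": "9", "WORLD_SIZE": "16", "LOCAL_RANK": "1"})
```

For example, with two nodes of eight GPUs each, global rank 9 is local rank 1 on the second node, so it trains on cuda:1 of that machine.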