LLM Fine-tuning¶
LLM fine-tuning with SPDL data loading pipeline.
Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA, with SPDL PipelineBuilder for high-performance concurrent tokenization.
SPDL Data Pipeline¶
The core of this example is the SPDL data loading pipeline:
1. DistributedRandomSampler – distributes sample indices across ranks with per-epoch reshuffling
2. pipe(tokenize, concurrency=N) – concurrent Alpaca-format prompt formatting and tokenization
3. aggregate(batch_size) – groups into batches
4. pipe(collate) – stacks tensors
5. add_sink(buffer_size=3) – prefetch buffer for the training loop
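The data flow through these stages can be mimicked with plain Python generators (a conceptual sketch with stand-in functions; the real pipeline is assembled with spdl.pipeline.PipelineBuilder and runs the tokenize stage concurrently):

```python
import random


def sampler(num_samples: int, seed: int):
    # Stand-in for DistributedRandomSampler: a shuffled index stream for one rank.
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    yield from indices


def tokenize(sample: dict) -> dict:
    # Stand-in for the concurrent prompt-formatting + tokenization stage.
    return {"text": f"{sample['instruction']}\n{sample['output']}"}


def aggregate(stream, batch_size: int):
    # Stand-in for .aggregate(batch_size): group items into fixed-size batches.
    batch = []
    for item in stream:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []


def collate(items: list) -> list:
    # Stand-in for the tensor-stacking collate stage.
    return [it["text"] for it in items]


samples = [{"instruction": f"Q{i}", "output": f"A{i}"} for i in range(8)]
tokenized = (tokenize(samples[i]) for i in sampler(len(samples), seed=0))
batches = [collate(b) for b in aggregate(tokenized, batch_size=4)]
print(len(batches), len(batches[0]))  # 2 4
```

In the real pipeline each stage runs as an asynchronous task and the sink buffers finished batches, so tokenization overlaps with the training step instead of running inline as it does here.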
Data¶
Download instruction-following datasets:
# https://github.com/tatsu-lab/stanford_alpaca
python download_alpaca.py --output /tmp/alpaca.jsonl
# https://huggingface.co/datasets/databricks/databricks-dolly-15k
python download_dolly.py --output /tmp/dolly.jsonl
Data format (JSONL with Alpaca-style fields):
{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}
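A record like the one above is turned into a single training prompt before tokenization. The sketch below illustrates a typical Alpaca-style template (the exact wording lives in the example's format_prompt helper and may differ; format_prompt_sketch here is a hypothetical stand-in):

```python
import json


def format_prompt_sketch(sample: dict) -> str:
    # Alpaca-style template: emit the "### Input:" section only when present.
    if sample.get("input"):
        return (
            "Below is an instruction that describes a task, paired with an input.\n\n"
            f"### Instruction:\n{sample['instruction']}\n\n"
            f"### Input:\n{sample['input']}\n\n"
            f"### Response:\n{sample['output']}"
        )
    return (
        "Below is an instruction that describes a task.\n\n"
        f"### Instruction:\n{sample['instruction']}\n\n"
        f"### Response:\n{sample['output']}"
    )


line = '{"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}'
sample = json.loads(line)  # one JSONL record
prompt = format_prompt_sketch(sample)
print(prompt.splitlines()[2])  # ### Instruction:
```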
Usage¶
torchrun \
--nproc_per_node 8 \
-m spdl.examples.llm_finetune.llm_finetuning \
--model-path /path/to/Llama-3.2-1B-Instruct \
--data-path \
/tmp/alpaca.jsonl \
/tmp/dolly.jsonl
With the default settings (global batch size 8x32), training throughput reaches roughly 570 samples/sec on H100 GPUs.
Source¶
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM fine-tuning with SPDL data loading pipeline.

Fine-tunes LLaMA 3.2 1B on Alpaca-style instruction data using LoRA,
with SPDL PipelineBuilder for high-performance concurrent tokenization.

SPDL Data Pipeline
^^^^^^^^^^^^^^^^^^

The core of this example is the SPDL data loading pipeline:

1. ``DistributedRandomSampler`` -- distributes sample indices across ranks
   with per-epoch reshuffling
2. ``pipe(tokenize, concurrency=N)`` -- concurrent Alpaca-format prompt
   formatting and tokenization
3. ``aggregate(batch_size)`` -- groups into batches
4. ``pipe(collate)`` -- stacks tensors
5. ``add_sink(buffer_size=3)`` -- prefetch buffer for the training loop

Data
^^^^

Download instruction-following datasets::

    # https://github.com/tatsu-lab/stanford_alpaca
    python download_alpaca.py --output /tmp/alpaca.jsonl
    # https://huggingface.co/datasets/databricks/databricks-dolly-15k
    python download_dolly.py --output /tmp/dolly.jsonl

Data format (JSONL with Alpaca-style fields)::

    {"instruction": "Explain what a linked list is.", "input": "", "output": "A linked list is..."}

Usage
^^^^^

::

    torchrun \\
        --nproc_per_node 8 \\
        -m spdl.examples.llm_finetune.llm_finetuning \\
        --model-path /path/to/Llama-3.2-1B-Instruct \\
        --data-path \\
        /tmp/alpaca.jsonl \\
        /tmp/dolly.jsonl

With the default settings (global batch size 8x32), the training throughput
reaches roughly 570 samples/sec on H100 GPUs.
"""

from __future__ import annotations

__all__ = [
    "build_model",
    "build_pipeline",
    "iterate_pipeline",
    "load_data",
    "main",
    "train",
]

# pyre-strict

import argparse
import logging
import os
import threading
import time
from collections.abc import Callable, Iterator
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist
from spdl.pipeline import PipelineBuilder
from spdl.source import DistributedRandomSampler
from torch.nn.parallel import DistributedDataParallel as DDP

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase

try:
    from examples.llm_finetune.utils.utils import (  # pyre-ignore[21]
        format_prompt,
        load_data,
        report_progress,
        resolve_model_path,
    )
except ImportError:
    from spdl.examples.llm_finetune.utils.utils import (
        format_prompt,
        load_data,
        report_progress,
        resolve_model_path,
    )

_LG: logging.Logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# SPDL data pipeline
# ---------------------------------------------------------------------------


def build_pipeline(
    samples: list[dict[str, str]],
    tokenizer: PreTrainedTokenizerBase,
    max_seq_len: int,
    batch_size: int,
    rank: int,
    world_size: int,
    num_threads: int,
    seed: int,
) -> PipelineBuilder:
    """Build an SPDL pipeline for concurrent tokenization.

    Pipeline stages:
        1. Source: DistributedRandomSampler yields sample indices
        2. Pipe: Look up sample by index
        3. Pipe (concurrent): Format prompt and tokenize
        4. Aggregate: Group into batches
        5. Pipe: Collate into tensors
        6. Sink: Buffer for the training loop
    """

    sampler = DistributedRandomSampler(
        len(samples),
        rank=rank,
        world_size=world_size,
        seed=seed,
    )

    # The HuggingFace fast tokenizer's Rust backend is not thread-safe.
    # Use thread-local copies so each SPDL worker thread has its own instance.
    class _TokenizerTLS(threading.local):
        tokenizer: PreTrainedTokenizerBase | None = None

    _tls: _TokenizerTLS = _TokenizerTLS()

    def _get_tokenizer() -> PreTrainedTokenizerBase:
        if _tls.tokenizer is None:
            import copy

            _tls.tokenizer = copy.deepcopy(tokenizer)
        return _tls.tokenizer

    def lookup(idx: int) -> dict[str, str]:
        return samples[idx]

    def tokenize(sample: dict[str, str]) -> dict[str, torch.Tensor]:
        tok = _get_tokenizer()
        text = format_prompt(sample)
        enc = tok(
            text,
            max_length=max_seq_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)
        # For causal LM, labels = input_ids; mask padding with -100
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def collate(items: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
        return {
            "input_ids": torch.stack([it["input_ids"] for it in items]),
            "attention_mask": torch.stack([it["attention_mask"] for it in items]),
            "labels": torch.stack([it["labels"] for it in items]),
        }

    return (
        PipelineBuilder()
        .add_source(sampler)
        .pipe(lookup)
        .pipe(tokenize, concurrency=num_threads)
        .aggregate(batch_size)
        .pipe(collate)
        .add_sink(buffer_size=3)
    )


def iterate_pipeline(
    pipeline_builder: PipelineBuilder,
    num_threads: int,
    device: torch.device,
) -> Iterator[dict[str, torch.Tensor]]:
    """Build, run, and iterate over the SPDL pipeline, transferring to device."""
    pipeline = pipeline_builder.build(num_threads=num_threads)
    with pipeline.auto_stop():
        for batch in pipeline.get_iterator(timeout=120):
            yield {k: v.to(device, non_blocking=True) for k, v in batch.items()}


# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------


def build_model(
    model_path: str,
    device: torch.device,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
) -> torch.nn.Module:
    """Load LLaMA model and apply LoRA."""
    from peft import get_peft_model, LoraConfig, TaskType
    from transformers import AutoModelForCausalLM

    _LG.info("Loading model from %s", model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = model.to(device)
    return model


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------


def train(
    *,
    model_path: str,
    data_path: list[str],
    output_dir: str,
    max_seq_len: int,
    batch_size: int,
    num_epochs: int,
    lr: float,
    weight_decay: float,
    max_grad_norm: float,
    log_interval: int,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    num_workers: int,
    progress_fn: Callable[[int, int], None] | None = None,
) -> None:
    """Main training function, called per-rank."""
    rank: int = dist.get_rank()
    world_size: int = dist.get_world_size()
    local_rank: int = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    _LG.info("Rank %d/%d on device %s", rank, world_size, device)

    # --- Data ---
    samples = load_data(data_path)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model = build_model(
        model_path,
        device,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    ddp_model = DDP(model, device_ids=[local_rank])

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        ddp_model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        foreach=True,
    )

    num_steps_per_epoch = len(samples) // (batch_size * world_size)
    total_steps = num_steps_per_epoch * num_epochs
    if rank == 0:
        _LG.info(
            "Training: %d samples, %d epochs, %d steps/epoch, %d total steps",
            len(samples),
            num_epochs,
            num_steps_per_epoch,
            total_steps,
        )
        if progress_fn is not None:
            progress_fn(0, total_steps)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=lr * 0.1,
    )

    # --- Training loop ---
    global_step = 0
    ddp_model.train()

    for epoch in range(num_epochs):
        _LG.info("Epoch %d/%d", epoch + 1, num_epochs)

        pipeline_builder = build_pipeline(
            samples=samples,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            rank=rank,
            world_size=world_size,
            num_threads=num_workers,
            seed=epoch,  # different shuffle per epoch
        )

        t0 = time.monotonic()
        epoch_loss = 0.0
        num_batches = 0

        for batch in iterate_pipeline(pipeline_builder, num_workers, device):
            outputs = ddp_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            num_batches += 1
            global_step += 1

            if rank == 0:
                if progress_fn is not None:
                    progress_fn(global_step, total_steps)
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / num_batches
                    elapsed = time.monotonic() - t0
                    _LG.info(
                        "Step %d | loss=%.4f | lr=%.2e | %.1f samples/s",
                        global_step,
                        avg_loss,
                        scheduler.get_last_lr()[0],
                        num_batches * batch_size * world_size / elapsed,
                    )

        elapsed = time.monotonic() - t0
        if rank == 0:
            avg_loss = epoch_loss / max(num_batches, 1)
            _LG.info(
                "Epoch %d complete | avg_loss=%.4f | %.1fs | %.1f samples/s",
                epoch + 1,
                avg_loss,
                elapsed,
                num_batches * batch_size * world_size / elapsed,
            )

    # --- Save ---
    if rank == 0 and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        ddp_model.module.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        _LG.info("Model saved to %s", output_path)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    # Model
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to pretrained LLaMA model directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Directory to save fine-tuned LoRA weights",
    )
    # Data
    parser.add_argument(
        "--data-path",
        type=str,
        nargs="+",
        required=True,
        help="One or more paths to Alpaca-format JSONL files (local or manifold://).",
    )
    parser.add_argument("--max-seq-len", type=int, default=512)
    # Training
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--log-interval", type=int, default=10)
    # LoRA
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Pipeline
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Concurrent tokenization workers in the SPDL pipeline",
    )
    return parser.parse_args()


def init_logging() -> None:
    """Initialize logging."""
    rank = os.environ.get("RANK", "?")
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s [%(levelname).1s] [Rank{rank}] %(name)s: %(message)s",
    )


def main(args: argparse.Namespace) -> None:
    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))
    try:
        train(
            model_path=resolve_model_path(args.model_path),
            data_path=args.data_path,
            output_dir=args.output_dir,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            max_grad_norm=args.max_grad_norm,
            log_interval=args.log_interval,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            num_workers=args.num_workers,
            progress_fn=report_progress,
        )
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    init_logging()
    main(parse_args())
API Reference¶
Functions
- build_model(model_path: str, device: device, lora_r: int, lora_alpha: int, lora_dropout: float) → Module[source]¶
Load LLaMA model and apply LoRA.
- build_pipeline(samples: list[dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch_size: int, rank: int, world_size: int, num_threads: int, seed: int) → PipelineBuilder[source]¶
Build an SPDL pipeline for concurrent tokenization.
- Pipeline stages:
  1. Source: DistributedRandomSampler yields sample indices
  2. Pipe: Look up sample by index
  3. Pipe (concurrent): Format prompt and tokenize
  4. Aggregate: Group into batches
  5. Pipe: Collate into tensors
  6. Sink: Buffer for the training loop
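The tokenize stage (stage 3) emits input_ids, attention_mask, and labels per sample, where labels copies input_ids but sets padding positions to -100 so the loss ignores them (labels[attention_mask == 0] = -100 in the source). The masking rule can be illustrated without tensors (a sketch with plain lists and made-up token IDs; the real code operates on torch tensors):

```python
def mask_labels(input_ids: list[int], attention_mask: list[int], ignore_index: int = -100) -> list[int]:
    # labels = input_ids, except positions where attention_mask == 0 (padding)
    return [tok if m == 1 else ignore_index for tok, m in zip(input_ids, attention_mask)]


input_ids = [128000, 9906, 1917, 0, 0]  # last two positions are padding
attention_mask = [1, 1, 1, 0, 0]
print(mask_labels(input_ids, attention_mask))  # [128000, 9906, 1917, -100, -100]
```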
- iterate_pipeline(pipeline_builder: PipelineBuilder, num_threads: int, device: device) → Iterator[dict[str, Tensor]][source]¶
Build, run, and iterate over the SPDL pipeline, transferring to device.
- load_data(paths: Sequence[str]) → list[dict[str, str]][source]¶
Load and concatenate data from one or more JSONL files.
- train(*, model_path: str, data_path: list[str], output_dir: str, max_seq_len: int, batch_size: int, num_epochs: int, lr: float, weight_decay: float, max_grad_norm: float, log_interval: int, lora_r: int, lora_alpha: int, lora_dropout: float, num_workers: int, progress_fn: Callable[[int, int], None] | None = None) → None[source]¶
Main training function, called per-rank.