Source code for fairseq2.recipes.llama.convert_checkpoint
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json
import sys
import warnings
from argparse import ArgumentParser, Namespace
from itertools import count
from pathlib import Path
from typing import final
from warnings import catch_warnings

from typing_extensions import override

from fairseq2.assets import default_asset_store
from fairseq2.logging import get_log_writer
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.integ import convert_to_reference_checkpoint
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.recipes.console import get_error_console
from fairseq2.setup import setup_fairseq2
from fairseq2.utils.file import dump_torch_tensors, load_torch_tensors

log = get_log_writer(__name__)
@final
class ConvertCheckpointCommandHandler(CliCommandHandler):
    """Converts fairseq2 LLaMA checkpoints to reference checkpoints."""
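    # A "reference" checkpoint follows the layout of Meta's original LLaMA
    # releases: one `consolidated.NN.pth` file per shard, plus a `params.json`
    # describing the architecture.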
    @override
    def init_parser(self, parser: ArgumentParser) -> None:
        parser.add_argument(
            "--model",
            metavar="ARCH_NAME",
            help="model name to fetch architecture to generate params.json",
        )

        parser.add_argument(
            "input_dir",
            type=Path,
            help="checkpoint directory",
        )

        parser.add_argument(
            "output_dir",
            type=Path,
            help="output directory to store reference checkpoint",
        )
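    # `run()` proceeds in three stages: validate the input and output
    # directories, convert each `model[.N].pt` shard into a
    # `consolidated.NN.pth` file, and, when `--model` resolves to a known
    # asset card, emit a matching `params.json`.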
    @override
    def run(self, args: Namespace) -> int:
        if not args.input_dir.exists() or not args.input_dir.is_dir():
            log.error("`input_dir` must be a directory.")

            sys.exit(1)

        if args.output_dir.exists():
            log.error("`output_dir` must not exist.")

            sys.exit(1)

        setup_fairseq2()

        arch = (
            default_asset_store.retrieve_card(args.model).field("model_arch").as_(str)
        )

        if arch:
            model_config = load_llama_config(args.model)
        else:
            model_config = None

        input_files = []

        # Determine input checkpoint files.
        input_file = args.input_dir.joinpath("model.pt")
        if input_file.exists():
            input_files.append(input_file)
        else:
            for shard_idx in count():
                input_file = args.input_dir.joinpath(f"model.{shard_idx}.pt")
                if not input_file.exists():
                    break

                input_files.append(input_file)

        if not input_files:
            log.error("`input_dir` must contain a model checkpoint file (i.e. model.pt)")  # fmt: skip

            sys.exit(1)

        output_files = []

        # Determine output checkpoint filenames.
        for shard_idx in range(len(input_files)):
            output_file = args.output_dir.joinpath(f"consolidated.{shard_idx:02d}.pth")

            output_files.append(output_file)

        args.output_dir.mkdir(parents=True)

        # Begin conversion.
        with get_error_console().status("[bold green]Converting...") as status:
            for input_file, output_file in zip(input_files, output_files):
                status.update(f"[bold green]Loading {input_file.name}...")

                try:
                    with catch_warnings():
                        warnings.simplefilter("ignore")

                        checkpoint = load_torch_tensors(input_file, restrict=True)
                except RuntimeError:
                    log.exception("Checkpoint file {} cannot be loaded.", input_file.name)

                    sys.exit(1)

                if "model" not in checkpoint:
                    log.error("Checkpoint file {} does not contain a 'model' entry.", input_file.name)  # fmt: skip

                    sys.exit(1)

                status.update(f"[bold green]Converting {input_file.name} to {output_file.name}...")

                ref_state_dict = convert_to_reference_checkpoint(checkpoint)

                try:
                    dump_torch_tensors(ref_state_dict, output_file)
                except RuntimeError:
                    log.exception("Checkpoint file {} cannot be saved.", output_file.name)

                    sys.exit(1)

                log.info("{} converted!", input_file.name)

        # Generate a basic params.json, mainly to use with HF transformers.
        if model_config is not None:
            params = {
                "model": {
                    "dim": model_config.model_dim,
                    "n_layers": model_config.num_layers,
                    "n_heads": model_config.num_attn_heads,
                    "multiple_of": model_config.ffn_inner_dim_to_multiple,
                    "rope_theta": model_config.rope_theta,
                    "norm_eps": 1e-5,
                },
            }

            if model_config.num_attn_heads != model_config.num_key_value_heads:
                params["model"]["n_kv_heads"] = model_config.num_key_value_heads

            # we only specify archs where multiplier != 1.0
            ffn_dim_multipliers = {
                "llama2_70b": 1.3,
                "llama3_8b": 1.3,
                "llama3_70b": 1.3,
                "llama3_1_8b": 1.3,
                "llama3_1_70b": 1.3,
                "llama3_1_405b": 1.2,
                "llama3_2_1b": 1.5,
            }

            if arch in ffn_dim_multipliers:
                params["model"]["ffn_dim_multiplier"] = ffn_dim_multipliers[arch]

            try:
                with args.output_dir.joinpath("params.json").open("w") as fp:
                    json.dump(params, fp)
            except RuntimeError:
                log.exception("params.json cannot be created.")

                sys.exit(1)

            log.info("params.json generated for {}.", args.model)

        return 0
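Assuming this handler is registered under the fairseq2 CLI's `llama` command group (as in fairseq2's stock recipes), a typical invocation would be `fairseq2 llama convert_checkpoint --model llama3_1_8b <input_dir> <output_dir>`. For a single unsharded checkpoint, the core of the conversion can also be sketched programmatically with the same helpers this module imports; the paths below are hypothetical:

from pathlib import Path

from fairseq2.models.llama.integ import convert_to_reference_checkpoint
from fairseq2.utils.file import dump_torch_tensors, load_torch_tensors

# Hypothetical locations; substitute your own checkpoint directory.
input_file = Path("/checkpoints/my_run/model.pt")
output_file = Path("/checkpoints/my_run_ref/consolidated.00.pth")

output_file.parent.mkdir(parents=True, exist_ok=True)

# Load the fairseq2 checkpoint (a dict expected to contain a "model" entry),
# rewrite its state dict into the reference layout, and save it back out.
checkpoint = load_torch_tensors(input_file, restrict=True)

ref_state_dict = convert_to_reference_checkpoint(checkpoint)

dump_torch_tensors(ref_state_dict, output_file)

Note that this sketch skips the `params.json` generation; without a `--model` asset card there is no architecture metadata to derive it from.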