使用全数据并行(FSDP)进行高级模型训练
作者: Hamid Shojanazeri,Less Wright,Rohan Varma,Yanli Zhao
本教程介绍了 PyTorch 1.12 发行版中 Fully Sharded Data Parallel (FSDP) 的更多高级功能。要熟悉 FSDP,请参阅 FSDP 入门教程。
在本教程中,我们将使用 FSDP 对 HuggingFace (HF) 的 T5 模型进行微调,并通过文本摘要来演示其应用。
该示例使用了Wikihow,并为了简化,我们将在一个配备8个A100 GPU的P4dn实例上展示单节点训练过程。不久我们将发布一篇关于多节点集群上的大规模FSDP训练的博客文章,请关注PyTorch官方频道获取更多信息。
FSDP 是一个专注于易用性、性能和长期支持的生产级包。FSDP 的主要优势之一是减少每个 GPU 上的内存占用,从而使得与 DDP 相比,在更低的总内存下训练更大规模的模型成为可能,并通过计算和通信的重叠来高效地训练模型。这种降低的内存压力可以用于训练更大规模的模型或增加批量大小,进而有可能提高整体训练吞吐量。你可以在这里了解更多关于 PyTorch FSDP 的信息。
本教程中的FSDP功能
-
Transformer 自动换行策略
-
混合精度
-
在设备上初始化FSDP模型
-
分片策略
-
反向预取
-
通过流式传输到CPU来保存模型检查点
FSDP工作原理回顾
总体而言,FDSP 的工作原理如下:
在构造函数中
-
分割模型参数,每个 ranks 只保留自己的分片。
在前向传播过程中
-
运行all_gather来收集所有rank上的分片,并恢复此FSDP单元的完整参数,然后进行前向计算
-
丢弃刚刚收集的非自有参数分片,以释放内存
在反向传播过程中
-
运行all_gather以从所有排名收集所有分片,在此FSDP单元中恢复完整参数,然后执行反向计算。
-
丢弃非拥有的参数以释放内存。
-
运行 reduce_scatter 以同步梯度
HF T5的微调
HF T5 预训练模型有四种不同的大小,参数量从60百万到110亿不等。在本教程中,我们将展示如何使用FSDP对T5 3B进行微调以实现WikiHow数据集上的文本摘要任务。本教程的主要目的是突出FSDP在训练超过30亿参数的大规模模型时可用的不同功能。此外,我们还介绍了针对基于Transformer的模型的特定功能。本教程的代码可在PyTorch 实用技巧中找到。
设置
1.1 安装 PyTorch 夜间版
pip3install--pretorchtorchvisiontorchaudio-fhttps://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
1.2 数据集设置
请创建一个data文件夹,从wikihowAll.csv 和 wikihowSep.cs 下载 WikiHow 数据集,并将其放置在data文件夹中。我们将使用summarization_dataset 中的 wikihow 数据集。
接下来,我们将以下代码片段添加到名为“T5_training.py”的Python脚本中。
注
本教程的完整源代码可在 PyTorch 示例 中获取。
1.3 导入必要的包:
import os import argparse import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from transformers import AutoTokenizer, GPT2TokenizerFast from transformers import T5Tokenizer, T5ForConditionalGeneration import functools from torch.optim.lr_scheduler import StepLR import torch.nn.functional as F import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler from transformers.models.t5.modeling_t5 import T5Block from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, apply_activation_checkpointing_wrapper) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, MixedPrecision, BackwardPrefetch, ShardingStrategy, FullStateDictConfig, StateDictType, ) from torch.distributed.fsdp.wrap import ( transformer_auto_wrap_policy, enable_wrap, wrap, ) from functools import partial from torch.utils.data import DataLoader from pathlib import Path from summarization_dataset import * from transformers.models.t5.modeling_t5 import T5Block from typing import Type import time import tqdm from datetime import datetime
1.4 分布式训练设置。这里我们使用两个辅助函数来初始化分布式训练,并在训练完成后进行清理。在这个教程中,我们将使用 torch elastic,并通过torchrun 来执行,这会自动设置工作进程的 RANK 和 WORLD_SIZE。
def setup(): # initialize the process group dist.init_process_group("nccl") def cleanup(): dist.destroy_process_group()
2.1 设置 HuggingFace T5 模型:
def setup_model(model_name): model = T5ForConditionalGeneration.from_pretrained(model_name) tokenizer = T5Tokenizer.from_pretrained(model_name) return model, tokenizer
我们在这里还添加了几个辅助函数,用于处理日期和格式化内存指标。
def get_date_of_run(): """create date and time for file save uniqueness example: 2022-05-07-08:31:12_PM' """ date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") print(f"--> current date and time of run = {date_of_run}") return date_of_run def format_metrics_to_gb(item): """quick function to format numbers to gigabyte and round to 4 digit precision""" metric_num = item / g_gigabyte metric_num = round(metric_num, ndigits=4) return metric_num
2.2 定义训练函数:
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): model.train() local_rank = int(os.environ['LOCAL_RANK']) fsdp_loss = torch.zeros(2).to(local_rank) if sampler: sampler.set_epoch(epoch) if rank==0: inner_pbar = tqdm.tqdm( range(len(train_loader)), colour="blue", desc="r0 Training Epoch" ) for batch in train_loader: for key in batch.keys(): batch[key] = batch[key].to(local_rank) optimizer.zero_grad() output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] ) loss = output["loss"] loss.backward() optimizer.step() fsdp_loss[0] += loss.item() fsdp_loss[1] += len(batch) if rank==0: inner_pbar.update(1) dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) train_accuracy = fsdp_loss[0] / fsdp_loss[1] if rank == 0: inner_pbar.close() print( f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}" ) return train_accuracy
2.3 定义一个验证函数:
def validation(model, rank, world_size, val_loader): model.eval() correct = 0 local_rank = int(os.environ['LOCAL_RANK']) fsdp_loss = torch.zeros(3).to(local_rank) if rank == 0: inner_pbar = tqdm.tqdm( range(len(val_loader)), colour="green", desc="Validation Epoch" ) with torch.no_grad(): for batch in val_loader: for key in batch.keys(): batch[key] = batch[key].to(local_rank) output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"]) fsdp_loss[0] += output["loss"].item() # sum up batch loss fsdp_loss[1] += len(batch) if rank==0: inner_pbar.update(1) dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) val_loss = fsdp_loss[0] / fsdp_loss[1] if rank == 0: inner_pbar.close() print(f"Validation Loss: {val_loss:.4f}") return val_loss
2.4 定义一个使用 FSDP 封装模型的分布式训练函数。
def fsdp_main(args): model, tokenizer = setup_model("t5-base") local_rank = int(os.environ['LOCAL_RANK']) rank = int(os.environ['RANK']) world_size = int(os.environ['WORLD_SIZE']) dataset = load_dataset('wikihow', 'all', data_dir='data/') print(dataset.keys()) print("Size of train dataset: ", dataset['train'].shape) print("Size of Validation dataset: ", dataset['validation'].shape) #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False) train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False) sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True) sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size) setup() train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1} test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2} cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False} train_kwargs.update(cuda_kwargs) test_kwargs.update(cuda_kwargs) train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs) val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs) t5_auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls={ T5Block, }, ) sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3 torch.cuda.set_device(local_rank) #init_start_event = torch.cuda.Event(enable_timing=True) #init_end_event = torch.cuda.Event(enable_timing=True) #init_start_event.record() bf16_ready = ( torch.version.cuda and torch.cuda.is_bf16_supported() and LooseVersion(torch.version.cuda) >= "11.0" and dist.is_nccl_available() and nccl.version() >= (2, 10) ) if bf16_ready: mp_policy = bfSixteen else: mp_policy = None # defaults to fp32 # model is on CPU before input to FSDP model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy, mixed_precision=mp_policy, #sharding_strategy=sharding_strategy, device_id=torch.cuda.current_device()) optimizer = optim.AdamW(model.parameters(), lr=args.lr) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) best_val_loss = float("inf") curr_val_loss = float("inf") file_save_name = "T5-model-" if rank == 0: time_of_run = get_date_of_run() dur = [] train_acc_tracking = [] val_acc_tracking = [] training_start_time = time.time() if rank == 0 and args.track_memory: mem_alloc_tracker = [] mem_reserved_tracker = [] for epoch in range(1, args.epochs + 1): t0 = time.time() train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1) if args.run_validation: curr_val_loss = validation(model, rank, world_size, val_loader) scheduler.step() if rank == 0: print(f"--> epoch {epoch} completed...entering save and stats zone") dur.append(time.time() - t0) train_acc_tracking.append(train_accuracy.item()) if args.run_validation: val_acc_tracking.append(curr_val_loss.item()) if args.track_memory: mem_alloc_tracker.append( format_metrics_to_gb(torch.cuda.memory_allocated()) ) mem_reserved_tracker.append( format_metrics_to_gb(torch.cuda.memory_reserved()) ) print(f"completed save and stats zone...") if args.save_model and curr_val_loss < best_val_loss: # save if rank == 0: print(f"--> entering save model state") save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, save_policy ): cpu_state = model.state_dict() #print(f"saving process: rank {rank} done w state_dict") if rank == 0: print(f"--> saving model ...") currEpoch = ( "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt" ) print(f"--> attempting to save model prefix {currEpoch}") save_name = file_save_name + "-" + time_of_run + "-" + currEpoch print(f"--> saving as model name {save_name}") torch.save(cpu_state, save_name) if curr_val_loss < best_val_loss: best_val_loss = curr_val_loss if rank==0: print(f"-->>>> New Val Loss Record: {best_val_loss}") dist.barrier() cleanup()
2.5 解析参数并设置主函数:
if __name__ == '__main__': # Training settings parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example') parser.add_argument('--batch-size', type=int, default=4, metavar='N', help='input batch size for training (default: 64)') parser.add_argument('--test-batch-size', type=int, default=4, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--epochs', type=int, default=2, metavar='N', help='number of epochs to train (default: 3)') parser.add_argument('--lr', type=float, default=.002, metavar='LR', help='learning rate (default: .002)') parser.add_argument('--gamma', type=float, default=0.7, metavar='M', help='Learning rate step gamma (default: 0.7)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--track_memory', action='store_false', default=True, help='track the gpu memory') parser.add_argument('--run_validation', action='store_false', default=True, help='running the validation') parser.add_argument('--save-model', action='store_false', default=True, help='For Saving the current Model') args = parser.parse_args() torch.manual_seed(args.seed) fsdp_main(args)
使用 torchrun 运行训练:
torchrun--nnodes1--nproc_per_node4T5_training.py
变压器包装策略
如之前教程所述,auto_wrap_policy 是 FSDP 的一个特性,它使得自动拆分给定模型,并将模型、优化器和梯度碎片放入不同的 FSDP 单元变得更加容易。
对于某些架构(如Transformer编码器-解码器),模型的部分组件(例如嵌入表)会被编码器和解码器共享。在这种情况下,我们需要将嵌入表放置在外部的FSDP单元中,以便可以从编码器和解码器访问它。此外,通过为变换器注册层类,可以使其分片计划更加通信高效。在PyTorch 1.12版本中,FSDP添加了此支持,并且我们现在有了一个针对变换器的包装策略。
它可以这样创建,其中 T5Block 表示 T5 变换器层类(包含 MHSA 和 FFN)。
t5_auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls={ T5Block, }, ) torch.cuda.set_device(local_rank) model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy)
要查看封装的模型,你可以轻松地打印模型,并直观地检查分片和FSDP单元。
混合精度
FSDP 支持灵活的混合精度训练,并允许使用任意低精度类型(如 fp16 或 bfloat16)。目前,bfloat16 只在 Ampere GPU 上可用,因此在使用前需要确认其原生支持。例如,在 V100 GPU 上虽然可以运行 bfloat16,但由于是非原生环境运行,可能会导致显著的性能下降。
要检查是否原生支持 BFloat16,可以使用以下方式:
bf16_ready = ( torch.version.cuda and torch.cuda.is_bf16_supported() and LooseVersion(torch.version.cuda) >= "11.0" and dist.is_nccl_available() and nccl.version() >= (2, 10) )
FSDP 混合精度的一个优势在于可以对参数、梯度和缓冲区的不同精度级别进行精细化控制,具体如下:
fpSixteen = MixedPrecision( param_dtype=torch.float16, # Gradient communication precision. reduce_dtype=torch.float16, # Buffer precision. buffer_dtype=torch.float16, ) bfSixteen = MixedPrecision( param_dtype=torch.bfloat16, # Gradient communication precision. reduce_dtype=torch.bfloat16, # Buffer precision. buffer_dtype=torch.bfloat16, ) fp32_policy = MixedPrecision( param_dtype=torch.float32, # Gradient communication precision. reduce_dtype=torch.float32, # Buffer precision. buffer_dtype=torch.float32, )
请注意,如果未指定某种类型(如参数、reduce 或 buffer),则不会进行任何类型的转换。
grad_bf16=MixedPrecision(reduce_dtype=torch.bfloat16)
在版本 2.4 中,我们将相关的混合精度策略添加到 FSDP 包装器中即可:
model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy, mixed_precision=bfSixteen)
在我们的实验中,我们发现使用 BFloat16 进行训练可使速度提升高达 4 倍,并且在某些实验中内存减少约 30%,从而可以用于增加批处理大小。
在设备上初始化FSDP模型
在 1.12 版本中,FSDP 引入了一个 device_id 参数,用于在由 device_id 指定的设备上初始化输入的 CPU 模块。当整个模型无法放入单个 GPU 中,但可以放在主机的 CPU 内存中时,这个功能非常有用。如果指定了 device_id,FSDP 将基于每个 FSDP 单元将模型移动到指定设备上,在初始化过程中避免了 GPU 内存溢出问题,并且比基于 CPU 的初始化快得多:
torch.cuda.set_device(local_rank) model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy, mixed_precision=bfSixteen, device_id=torch.cuda.current_device())
反向预取
反向预取设置控制了请求下一个FSDP单元参数的时机。将其设置为BACKWARD_PRE后,在当前单位开始计算之前,可以提前发起并接收下一个FSDP单元的参数请求。这样可以将all_gather通信与梯度计算重叠起来,在内存消耗略有增加的情况下提高训练速度。它可以在2.4版本中的FSDP包装器中如下使用:
torch.cuda.set_device(local_rank) model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy, mixed_precision=bfSixteen, device_id=torch.cuda.current_device(), backward_prefetch = BackwardPrefetch.BACKWARD_PRE)
backward_prefetch有两种模式:BACKWARD_PRE 和 BACKWARD_POST。其中,BACKWARD_POST 模式意味着下一个 FSDP 单元的参数不会在当前 FSDP 单元处理完成之前被请求,从而最小化内存开销。在某些情况下,使用BACKWARD_PRE模式可以将模型训练速度提高2-10%,对于更大的模型甚至可以获得更高的加速效果。
通过流式传输到Rank0 CPU来保存模型检查点
为了使用 FULL_STATE_DICT 方式(与本地模型保存方式相同)保存模型检查点,PyTorch 1.12 提供了一些工具来支持大型模型的保存。
首先,可以通过指定一个FullStateDictConfig来允许状态字典仅在rank 0上进行填充,并将其卸载到CPU。
此功能可以按如下方式运行:
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, save_policy ): cpu_state = model.state_dict() if rank == 0: save_name = file_save_name + "-" + time_of_run + "-" + currEpoch torch.save(cpu_state, save_name)