前言

Megatron 的核心思想是通过张量并行(Tensor Parallelism)与流水线并行(Pipeline Parallelism)相结合的方式,在 GPU 集群上实现超大规模语言模型的高效训练。

张量并行和流水线并行,同属于模型并行,简单地理解就是将模型参数按不同维度切分成小块,放到不同GPU上并行计算。

张量并行原理

张量并行属于模型并行的一种,旨在将模型“纵向”切分成更小的“块”,放到不同GPU上进行并行计算。

理解模型的张量并行,需要从基础的矩阵乘法说起。

并行矩阵乘法

对于一个矩阵乘法$X @ W = Y$

image

假设W很大,导致单卡装不下,我们需要把W切开放到不同的卡上。显然我们可以沿着W的行维度,或者列维度进行切分。

按行切分

假设有N张GPU,就可以把权重W切成N份。下图展示N=2的切割方式。

image

W按行切分后,X就必须按列做对应切分。下面把按行切分的计算过程展示成并行的形式。

image

图中的每一行表示单独在一块GPU上的计算过程。

  • f 和 g 分别表示两个算子,每个算子都包括一组forward和backward操作。
  • f的forward是将X按列切分成两块,放到两个GPU上。相应的 backward是将两块GPU上的梯度$\frac{\partial L}{\partial X1}$和$\frac{\partial L}{\partial X2}$gather到一起。
  • g的forward是将Y1和Y2相加得到Y,backward是将$\frac{\partial L}{\partial Y}$广播到各个GPU上。
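
下面用一小段PyTorch代码在单机上模拟按行切分的计算过程(非Megatron源码,W1/W2等变量名仅为示意),可以验证两块“GPU”的部分结果相加后与完整的矩阵乘法等价:

import torch

torch.manual_seed(0)
X = torch.randn(4, 6)            # 输入 (b, k)
W = torch.randn(6, 8)            # 权重 (k, n),假设单卡装不下

Y_full = X @ W                   # 串行基准

# W按行切成两块,X相应按列切成两块
W1, W2 = W.chunk(2, dim=0)       # 各为 (k/2, n)
X1, X2 = X.chunk(2, dim=1)       # 各为 (b, k/2)

# 两块"GPU"各自独立计算部分结果
Y1 = X1 @ W1                     # (b, n)
Y2 = X2 @ W2                     # (b, n)

# g 的forward:部分结果相加,对应一次AllReduce
Y = Y1 + Y2
print(torch.allclose(Y, Y_full, atol=1e-5))   # True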

按列切分

W也可以按列切分。

image

展示成并行的形式。

image

  • 按列切分时,f 的forward是将X复制到各个GPU上,backward是将各个GPU上的梯度$\frac{\partial L}{\partial X}$ AllReduce到一起;
  • g 的forward将Y1和Y2 gather到一起,backward是将$\frac{\partial L}{\partial Y}$切分到各个GPU上;
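
按列切分同样可以用几行代码验证(仍为单机模拟,仅作示意):X整份复制到每块“GPU”,各自算出Y的一部分列,最后拼接即可。

import torch

torch.manual_seed(0)
X = torch.randn(4, 6)
W = torch.randn(6, 8)

# W按列切成两块,X不切分、直接复制
W1, W2 = W.chunk(2, dim=1)       # 各为 (6, 4)

Y1 = X @ W1                      # (4, 4)
Y2 = X @ W2                      # (4, 4)

# g 的forward:沿列把部分结果拼起来,对应一次All-Gather
Y = torch.cat([Y1, Y2], dim=1)
print(torch.allclose(Y, X @ W, atol=1e-5))    # True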

“按行”和“按列”切分权重的方法是Megatron-LM中的基础算子,MLP、Attention等权重的切分操作就是由这两个基础算子组合而成的。

MLP层张量并行

MLP的计算过程由两个矩阵乘法和非线性激活函数组成,示意图如下。

image

其中,GELU是激活函数,A和B分别为两个线性层的权重。在Transformer里,一般设h’ = 4h。假设现在有N块GPU,我们要把MLP层的权重拆到上面做计算,Megatron提供的拆分办法如下:

image

在MLP层中,对权重A采用“列切割”,对权重B采用“行切割”。

  • f 的forward计算:把输入X拷贝到两块GPU上,每块GPU即可独立做forward计算。
  • g 的forward计算:每块GPU上的forward的计算完毕,取得Z1和Z2后,GPU间做一次AllReduce,相加结果产生Z。
  • g 的backward计算:只需要把梯度 $\frac{\partial L}{\partial Z}$拷贝到两块GPU上,两块GPU就能各自独立做梯度计算。
  • f 的backward计算:当前层的梯度计算完毕,需要传递到下一层继续做梯度计算,此时需要求得 $\frac{\partial L}{\partial X}$ ,然后两块GPU做一次AllReduce,把各自的梯度相加即可

先对A“列切割”、再对B“行切割”的原因是可以保证各GPU上的计算相互独立,减少通讯量。MLP层做forward时产生一次AllReduce,做backward时也产生一次AllReduce。AllReduce的过程可分为Reduce-Scatter和All-Gather两个阶段,每个阶段的通讯量相等。设每个阶段的通讯量为$\Phi$,则一次AllReduce产生的通讯量为$2\Phi$,MLP层的总通讯量为$4\Phi$。
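
下面的单机模拟直观展示了“A列切、B行切”的组合(非Megatron源码,形状和变量名均为示意):每块“GPU”可以独立完成 GELU(X@Ai)@Bi,全程只在最后做一次求和,对应forward中唯一的一次AllReduce。

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b_s, h = 8, 16                       # b*s个token,hidden size为h
X = torch.randn(b_s, h)
A = torch.randn(h, 4 * h)            # 第一个线性层权重
B = torch.randn(4 * h, h)            # 第二个线性层权重

Z_full = F.gelu(X @ A) @ B           # 串行基准

A1, A2 = A.chunk(2, dim=1)           # A按列切分
B1, B2 = B.chunk(2, dim=0)           # B按行切分

# 每块"GPU"独立完成两个矩阵乘法和GELU,中间无需任何通讯
Z1 = F.gelu(X @ A1) @ B1
Z2 = F.gelu(X @ A2) @ B2

# 唯一的一次通讯:相加(AllReduce)
Z = Z1 + Z2
print(torch.allclose(Z, Z_full, atol=1e-4))   # True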

多头自注意力层

多头自注意力层包含每个头的SelfAttention计算,然后每个头的hidden向量concat后,再经过一个线性层。

其切分核心是将每个头的权重放到不同GPU上进行计算。

下图展示了当num_heads = 2时attention层的计算方法

image

对三个参数矩阵Q,K,V,按照“列切割”,每个头放到一块GPU上,做并行计算。对线性层B,按照“行切割”。切割的方式和MLP层基本一致,其forward与backward原理也一致。

在实际应用中,并不一定按照一个head占用一块GPU来切割权重,也可以让多个head占用一块GPU,这并不影响每块GPU独立计算的目标。

类比于MLP层,self-attention层在forward中做一次AllReduce,在backward中做一次AllReduce。总通讯量也是$4\Phi$。
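
如果想直观感受“按头切分”的做法,可以参考下面的单机模拟(非Megatron源码,省略了mask等细节,Wq/Wk/Wv/Wo等变量名仅为示意):Wq/Wk/Wv按列切出各头,输出线性层Wo按行切分,每块“GPU”独立算完一个头后,只需一次相加即可得到与串行计算一致的结果。

import torch
import torch.nn.functional as F

torch.manual_seed(0)
s, h, num_heads = 4, 8, 2
d = h // num_heads                       # 每个头的维度
X = torch.randn(s, h)
Wq, Wk, Wv = (torch.randn(h, h) for _ in range(3))
Wo = torch.randn(h, h)                   # 输出线性层

def one_head(rank):
    # 列切割:取出第rank个头对应的Wq/Wk/Wv列块
    cols = slice(rank * d, (rank + 1) * d)
    Q, K, V = X @ Wq[:, cols], X @ Wk[:, cols], X @ Wv[:, cols]
    attn = F.softmax(Q @ K.T / d ** 0.5, dim=-1)
    head = attn @ V                      # (s, d)
    # 行切割:每个头只和Wo中对应的行块相乘
    return head @ Wo[cols, :]            # (s, h)

# 两块"GPU"各自独立算一个头,最后一次相加(AllReduce)
Y_parallel = one_head(0) + one_head(1)

# 串行基准:标准多头注意力(各头concat后过Wo)
heads = []
for r in range(num_heads):
    cols = slice(r * d, (r + 1) * d)
    Q, K, V = X @ Wq[:, cols], X @ Wk[:, cols], X @ Wv[:, cols]
    heads.append(F.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V)
Y_full = torch.cat(heads, dim=-1) @ Wo

print(torch.allclose(Y_parallel, Y_full, atol=1e-4))   # True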

Embedding层

输入层和输出层,均存在embedding层。

Embedding层一般由两个部分组成:

  • word embedding:维度(v, h),其中v表示词表大小。
  • positional embedding:维度(max_s, h),其中max_s表示模型允许的最大序列长度。

通常positional embedding对显存压力不大,可以每个GPU上拷贝一份。但词表一般较大,因此需要把word embedding拆分到各个GPU上。

对于输入层embedding来说,将embedding矩阵按行切割,各自查表。具体的做法如下:

image

输入X过word embedding的前向过程,等价于用token的序号去word embedding中按行索引,是一个查表过程。

假设词表中有300个词,现在我们将word embedding拆分到两块GPU上,第一块GPU维护词表[0, 150),第二块GPU维护词表[150, 300)。当输入X去GPU上查找时,能找到的词,就正常返回词向量,找不到就把词向量中的全部元素都置0。按此方式查找完毕后,每块GPU上的数据做一次AllReduce,就能得到最终的输入。
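
这一“mask + 置零 + AllReduce”的查表逻辑可以用下面的小例子模拟(非Megatron源码,词表大小等数字仅为示意,用求和代替AllReduce):

import torch

torch.manual_seed(0)
vocab_size, h = 10, 4
WE = torch.randn(vocab_size, h)          # 完整的word embedding
WE1, WE2 = WE[:5], WE[5:]                # 两块"GPU"各维护一半词表

input_ids = torch.tensor([2, 7, 1, 9])

def partial_lookup(part, start, end):
    # 不在本地词表范围内的token:先把序号置0以便查表,查完再把结果清零
    mask = (input_ids < start) | (input_ids >= end)
    local_ids = input_ids - start
    local_ids[mask] = 0
    out = part[local_ids]
    out[mask] = 0.0
    return out

# 两块"GPU"各自查表,再相加(对应AllReduce),即得到完整的embedding结果
out = partial_lookup(WE1, 0, 5) + partial_lookup(WE2, 5, 10)
print(torch.allclose(out, WE[input_ids]))   # True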

对于输出层来说,它一般和输入层的word embedding共享参数(权重绑定),负责把隐层输出再映射回词表,得到每个位置在词表上的输出。

此时前向计算就是将embedding矩阵分割到各个GPU上,并行计算矩阵乘法。具体计算过程如下:

image

需要注意的是,当流水线并行大于1时,输入层和输出层不在同一个GPU上。因为输入层和输出层的embedding矩阵共享,所以反向传播时,需要通过AllReduce操作累加梯度。

Cross-entropy层

接着输出层计算损失函数。

image

一种方式是对Y1和Y2做一次All-Gather,把它们concat起来形成Y,然后对Y的每一行做softmax,最后计算交叉熵损失。但这会产生额外的通信量,当词表很大时,开销也不容忽视。

另一种优化方式,如下:

image

  • 每块GPU上,先对各自维护的那部分logits按行求$e$的和,得到各自GPU上的$\text{GPU\_sum}(e)$;
  • 将每块GPU上的结果做AllReduce,得到每行最终的$\text{sum}(e)$,也就是softmax中的分母。此时的通讯量和词表大小无关,为$b \times s$;
  • 在每块GPU上,即可计算各自维护部分的$e_i/\text{sum}(e)$,将其与真值做cross-entropy,得到每行的loss,按行加总后得到该GPU上的scalar Loss;
  • 将各个GPU上的scalar Loss做AllReduce,得到总Loss。此时通讯量为GPU数量N。
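
下面的小例子按上述步骤模拟了词表并行下的loss计算(非Megatron源码,为简洁起见省略了减去全局最大值的防溢出步骤,用求和代替AllReduce):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, vocab_size = 3, 8              # n_tokens即b*s
logits = torch.randn(n_tokens, vocab_size)
target = torch.tensor([1, 6, 3])

Y1, Y2 = logits[:, :4], logits[:, 4:]    # 两块"GPU"各持一半词表的logits

# 每块GPU先对本地logits按行求e的和,再"AllReduce"得到softmax分母
sum_e = Y1.exp().sum(dim=-1) + Y2.exp().sum(dim=-1)        # (n_tokens,)

# 每块GPU取出真值落在本地词表内的logit,其余位置补0,再"AllReduce"
def local_target_logit(part, start, end):
    mask = (target < start) | (target >= end)
    local = target - start
    local[mask] = 0
    vals = part[torch.arange(n_tokens), local]
    vals[mask] = 0.0
    return vals

target_logit = local_target_logit(Y1, 0, 4) + local_target_logit(Y2, 4, 8)

# loss = log(sum(e)) - 真值位置的logit,与标准交叉熵一致
loss = (sum_e.log() - target_logit).mean()
print(torch.allclose(loss, F.cross_entropy(logits, target)))   # True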

Megatron实现

Megatron模型并行训练的主要入口是pretrain方法。该方法依次执行以下四件事:

  1. 初始化Megatron,重点是初始化分布式环境;
  2. 定义模型架构并切割模型,设置优化器和lr调度器;
  3. 构造train/val/test数据集;
  4. 训练模型;

Megatron 框架中的张量并行机制,其实现依赖于两个前提:其一是在分布式训练环境中构建用于张量切片通信的进程组;其二是在模型定义阶段就将权重和激活值按维度切分,使其天然适配多设备协同计算。

初始化分布式环境

不管是张量并行(TP),流水线并行(PP)还是数据并行(DP),都需要在不同GPU之间传递数据。GPU之间的通信关系是需要分组的,以执行不同的并行方式,这就是初始化分布式环境做的事。

假设我们有2台机器(node0和node1),每台机器上有8块GPU,GPU的编号为0~15。 我们使用这16块GPU,做DP/TP/PP混合并行,如下图

image

  • MP:模型并行组(Model Parallelism)。假设一个完整的模型需要部署在8块GPU上,则如图所示,我们共部署了2个model replica(2个MP组)。MP组为:[[g0, g1, g4, g5, g8, g9, g12, g13], [g2, g3, g6, g7, g10, g11, g14, g15]]
  • TP:张量并行组(Tensor Parallelism)。对于模型的每一层,我们将其参数纵向切开,分别置于不同的GPU上,则图中一共有8个TP组。TP组为:[[g0, g1], [g4, g5], [g8, g9], [g12, g13], [g2, g3], [g6, g7], [g10, g11], [g14, g15]]
  • PP:流水线并行组(Pipeline Parallelism)。对于一个模型,我们将其不同的层放置于不同的GPU上,则图中一共有4个PP组。PP组为:[[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]]
  • DP:数据并行组(Data Parallelism)。经过上述切割,对维护相同模型部分的GPU,我们就可以做数据并行,则图中共有8个DP组。DP组为:[[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]]

确认好DP/TP/PP组,并分配好进程后,每个进程独立执行自己所维护的那部分模型的计算,实现并行训练。

进一步地,对不同组能设置不同进程间的通讯方案。例如属于一个DP组的g0和g2需要进行梯度通讯,属于一个PP组的g4和g8需要进行层间输出结果的通讯。
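
划分逻辑本身只是对进程序号做整除和切片,可以用下面几行纯Python代码枚举出上例(16卡、TP=2、PP=4、DP=2)中的各组,结果与后文initialize_model_parallel的实现一致(仅为示意,不涉及真实通信):

world_size, tp_size, pp_size = 16, 2, 4
dp_size = world_size // (tp_size * pp_size)              # 2
num_pp_groups = world_size // pp_size                    # 4

tp_groups = [list(range(i * tp_size, (i + 1) * tp_size))
             for i in range(world_size // tp_size)]
pp_groups = [list(range(i, world_size, num_pp_groups))
             for i in range(num_pp_groups)]

dp_groups = []
for i in range(pp_size):
    start, end = i * num_pp_groups, (i + 1) * num_pp_groups
    for j in range(tp_size):
        dp_groups.append(list(range(start + j, end, tp_size)))

print(tp_groups)   # [[0, 1], [2, 3], [4, 5], ..., [14, 15]]
print(pp_groups)   # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
print(dp_groups)   # [[0, 2], [1, 3], [4, 6], [5, 7], ..., [13, 15]]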

回到pretrain函数,它的第一行就通过initialize_megatron执行了分布式初始化:

def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    forward_step_func,
    valid_forward_step_func=None,
    extra_args_provider=None,
    args_defaults={},
):
    initialize_megatron(
        extra_args_provider=extra_args_provider, args_defaults=args_defaults
    )
    ...

总体来说,initialize_megatron这个方法实现了3个目的:

  • 设置分布式环境:初始化进程,分配GPU,并设置进程大组(group)。也即例子中的0~15号进程同属一个分布式进程大组
  • 制定DP/TP/PP分组策略,设置进程子组(subgroup)
  • 设置DeepSpeed ZeRO-R,对activation进行优化

初始化进程大组

借助torch.distributed 来实现这一步,它是pytorch用于设置分布式训练环境的偏底层API(distributed communication package)。

        init_method = "tcp://"
        master_ip = os.getenv("MASTER_ADDR", "localhost") # 获取rank=0进程的ip
        master_port = os.getenv("MASTER_PORT", "6000") # 获取rank=0进程的端口
        init_method += master_ip + ":" + master_port 
        print( 
            f"  > (rank={args.rank}) initializing process group: "
            f"world_size={args.world_size} "
            f"backend={args.distributed_backend} " 
            f"init_method={init_method}",
            flush=True,
        )
        timeout = datetime.timedelta(minutes=args.dist_timeout)
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            init_method=init_method,
            timeout=timeout
        )
        print(f"  > (rank={args.rank}) process group initialized")

设置DP/TP/PP组

设置完进程大组(group)后,我们就可以进一步设置进程子组(subgroup)了,也即设置DP/TP/PP组。

            mpu.initialize_model_parallel( # megatron/mpu/initialize.py
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
            )

核心函数initialize_model_parallel位于megatron/mpu/initialize.py下。mpu的含义是model parallelism utils,和模型并行设置相关的函数都放在这个目录下。它接收3个参数:

  • tensor_model_parallel_size:每个TP组的进程数量。例如图例中是2
  • pipeline_model_parallel_size:每个PP组的进程数量。例如图例中是4
  • virtual_pipeline_model_parallel_size:每个virtual PP组的进程数量。这是NVIDIA对Megatron做后续迭代时提出的一种优化方法。我们之后会单独开一篇文章来讲解。这里可暂时忽略(不是必须参数,可以传None值)。

定好了TP和PP,DP_size就能根据 world_size // (TP_size * PP_size)计算得出。

def initialize_model_parallel(
    tensor_model_parallel_size_=1,
    pipeline_model_parallel_size_=1,
    virtual_pipeline_model_parallel_size_=None,
):
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
        pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups: 
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
        2 model-parallel groups:
            [g0, g1, g4, g5, g8, g9, g12, g13], [g2, g3, g6, g7, g10, g11, g14, g15]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    if torch.distributed.get_rank() == 0: 
        print(
            "> initializing tensor model parallel with size {}".format(
                tensor_model_parallel_size_
            )
        )
        print( # 打印出流水线模型并行的度
            "> initializing pipeline model parallel with size {}".format(
                pipeline_model_parallel_size_
            )
        )
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized() # 确保torch已经做了分布式初始化
    world_size = torch.distributed.get_world_size() # 得到全局进程的总数
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)

    ensure_divisibility( # 后者表示一个完整模型所占的gpu数,我们要保证前者能被后者整除
        world_size, tensor_model_parallel_size * pipeline_model_parallel_size
    )
    # 在codegeex中,TP_size=8, PP_size=1,world_size = 1536,因此DP_size是1536/(8*1) = 192
    data_parallel_size = world_size // ( # 根据TP_size和PP_size,求出DP_size
        tensor_model_parallel_size * pipeline_model_parallel_size
    )

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size # TP的组数
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size # PP的组数
    num_data_parallel_groups = world_size // data_parallel_size # DP的组数

    if virtual_pipeline_model_parallel_size_ is not None: 
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
            virtual_pipeline_model_parallel_size_
        )

    rank = torch.distributed.get_rank() # 获取当前进程的全局rank

    # Build the data-parallel groups.(设置DP组)
    global _DATA_PARALLEL_GROUP # 保存当前进程所属的DP组(DP组划分形如[[0,2], [1,3]...],数字表示进程的全局序号)
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" 
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size): 
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size): 
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks) # 设置DP组
            if rank in ranks: 
                _DATA_PARALLEL_GROUP = group 

    # Build the model-parallel groups.(设置MP组)
    global _MODEL_PARALLEL_GROUP # 保存MP组
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    for i in range(data_parallel_size):
        ranks = [
            data_parallel_group_ranks[i]
            for data_parallel_group_ranks in all_data_parallel_group_ranks
        ]
        group = torch.distributed.new_group(ranks) # 设置MP组
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.(设置TP组)
    global _TENSOR_MODEL_PARALLEL_GROUP # 保存TP组
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is None
    ), "tensor model parallel group is already initialized"
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(
            i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size
        )
        group = torch.distributed.new_group(ranks) # 设置TP组
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).(设置PP组与embedding组)
    global _PIPELINE_MODEL_PARALLEL_GROUP # 设置PP组
    global _PIPELINE_GLOBAL_RANKS
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is None
    ), "pipeline model parallel group is already initialized"
    global _EMBEDDING_GROUP
    assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks) # 设置PP组
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1: 
            embedding_ranks = [ranks[0], ranks[-1]] 
        else:
            embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks) # 设置embedding组
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group

总结来说,我们采用torch.distributed.new_group(ranks)在进程大组下设置子组。每一次调用时传入的ranks是一个进程全局序号的列表,例如设置DP组时会依次传入[0,2]、[1,3]等。我们把当前进程所属的子组保存在全局变量中(例如_DATA_PARALLEL_GROUP),方便后续切割模型时使用。

同时,Megatron还定义了一组查询函数(如get_tensor_model_parallel_rank、get_tensor_model_parallel_world_size等),使得对于任意一个进程,我们都能查到它在DP/TP/PP组中的局部序号(local_rank),以及它对应的DP/TP/PP组的world_size,如下面的示意。
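
这些查询函数本质上是在对应的进程子组上调用torch.distributed的接口,以TP组为例(简化示意,完整实现在megatron/mpu/initialize.py中):

import torch

_TENSOR_MODEL_PARALLEL_GROUP = None   # 由initialize_model_parallel设置的全局变量

def get_tensor_model_parallel_group():
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        "tensor model parallel group is not initialized"
    return _TENSOR_MODEL_PARALLEL_GROUP

def get_tensor_model_parallel_world_size():
    # 当前进程所在TP组的进程总数
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())

def get_tensor_model_parallel_rank():
    # 当前进程在TP组内的局部序号(local rank)
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())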

Q:如何确保逻辑分组和物理GPU对应上,使得TP、DP尽量跑在同一个节点上?

A:torchrun能够确保rank按照节点连续分配,只要TP size不超过节点GPU数量,TP天然跑在同一个节点。

模型定义和切分

setup_model_and_optimizer是整个模型并行的入口函数。如下图所示,它主要由“定义模型架构并切割模型”、“设置optimizer”和“设置学习率”三部分组成,我们关注的重点在第一部分(get_model)上。

image

get_model(model_provider) 函数主要做了两件事:

  • 在CPU上定义模型。model_provider 是一个函数,调用它即可返回Megatron的分布式模型,也就是一个MegatronModule类的子类,这个将是下文要介绍的重点。
  • 把模型从CPU搬运至GPU上。这里有两种方法可供选择:
    • 方案一:借助deepspeed进行管理。
    • 方案二:手动搬运管理。这里需要我们做以下事情:
      • 显式搬运。即手动将模型搬运到当前进程所对应的GPU上。
      • 权重精度设定。
      • 初始化DP组。这里指的是定义DP组间forward、backward和梯度计算与通讯等方法。在Megatron中,TP和PP组的这些方法是人为定义的,而DP组则是可以用现成的(torch的DistributedDataParallel)。

在 Megatron 中,模型并行所需的张量分割与通信操作内嵌于模型定义之中,这里实现的模型就是Megatron的分布式模型。MegatronModule 是所有分布式模型和组件的基类,统一了模型组件的接口,并为模型并行(Model Parallelism)及训练/推理流程提供核心支持。

模型切分的基本原理,是利用可获取到的进程ID(全局或DP/TP/PP组内的),在不同进程依靠类似if..else.. 来解决模型差异化问题。

有两种方式对模型进行切割:

  • 方案一:先定义出完整的模型,并对模型参数做初始化,然后根据进程id取出相应子模型,搬运到GPU上
  • 方案二:直接根据进程id,设计好当前子模型,做参数初始化,搬运到GPU上

这两者从结果上可以等价,它们的核心差别,在于“随机种子”的设定,它关系到模型是否能够复现。例如,对于Word Embedding来说,切分的WE1和WE2间需要采用不同的随机种子。因为若采用相同的随机种子,则WE1和WE2的结果完全一样,这不等价于先随机初始化WE,再将它进行切割。

一般在TP/PP组内,设定不同的随机种子;而在DP组内,设定相同的随机种子。
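
Megatron通过一个CUDA RNG tracker来管理这些种子,其思路可以用下面的简化示意概括(非完整源码,偏移常数等细节以源码中的实现为准):在全局seed的基础上,为TP组内不同rank加上不同偏移得到“模型并行种子”,DP组内则共用默认种子;后文的_initialize_affine_weight_gpu正是通过get_cuda_rng_tracker().fork()切换到模型并行种子,来初始化切分后的权重。

import torch

def model_parallel_cuda_manual_seed(seed):
    """简化示意:TP组内各进程使用不同的种子,DP组内共享默认种子。"""
    # 模型并行种子 = 全局seed + 固定偏移 + TP组内rank(偏移常数仅为示意)
    tp_rank = get_tensor_model_parallel_rank()   # 来自megatron/mpu的查询函数
    model_parallel_seed = seed + 2718 + tp_rank
    # 默认(数据并行)随机状态:所有进程用同一个seed,保证DP组内初始化一致
    torch.cuda.manual_seed(seed)
    # 将模型并行种子注册进RNG tracker,
    # 切分权重做初始化时通过 get_cuda_rng_tracker().fork() 切换到它
    get_cuda_rng_tracker().add("model-parallel-rng", model_parallel_seed)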

方案1(先做整体初始化再切割)在代码里被称为“CPU上的初始化”(_initialize_affine_weight_cpu)。

def _initialize_affine_weight_cpu(
    weight,
    output_size,
    input_size,
    per_partition_size,
    partition_dim,
    init_method,
    stride=1,
    return_master_weight=False,
):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
    )

    # Initialize master weight
    # 先初始化完整的master_weight,比传入的weight大
    master_weight = torch.empty(
        output_size, input_size, dtype=torch.float, requires_grad=False
    )
    init_method(master_weight)
    args = get_args()
    master_weight = master_weight.to(dtype=args.params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(
        master_weight, per_partition_per_stride_size, dim=partition_dim
    )
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)  # 传入的weight是在cpu上创建的
    if return_master_weight:
        return master_weight
    return None

方案2(直接对局部初始化)被称为“在GPU上的初始化”(_initialize_affine_weight_gpu)。

def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
    )

    if ds_checkpointing.is_configured():
        global get_cuda_rng_tracker
        get_cuda_rng_tracker = ds_checkpointing.get_cuda_rng_tracker

    with get_cuda_rng_tracker().fork():
        init_method(weight)  # 传入的weight在GPU上创建了,直接初始化

通过use_cpu_initialization参数控制权重初始化方法,当使用“在GPU上的初始化”时,权重其实已经被搬运到GPU上了。

就像pytorch里的复杂模型是由一些基础模块组合而来一样,Megatron里的分布式模型也由一些基础模块组成,这些组件就好比pytorch里组件的“并行版”,完成了张量并行和流水线并行的实现。

以下介绍Megatron中常见的分布式模型组件的定义,而更为复杂的分布式Transformer层和具体的LLM模型,不过是这些基础的组件搭成的。

Embedding

Embedding类定义了word/position/segment embedding,并定义了输入X过embedding层的计算方法。

class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size 
        vocab_size: vocabulary size 
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        hidden_size, # 每个token的向量维度
        vocab_size, # 词表大小
        max_sequence_length, # 最长序列长度
        embedding_dropout_prob, # dropout probability for embeddings
        init_method, # 初始化权重的方法
        num_tokentypes=0, # 类似于Bert中的segment type
    ):
        super(Embedding, self).__init__()

        args = get_args()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes
        self.max_sequence_length = max_sequence_length

        # WE size: (vocab_size//TP_N, hidden_size)
        # TP_N表示TP组模型并行度
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size, init_method=self.init_method) 
        self._word_embeddings_key = 'word_embeddings'

        self.vocab_size = vocab_size

        # PE size: (max_seq_len, hidden_size)
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self.position_embeddings = self.position_embeddings.half()
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # TE_size:(num_tokentypes, hidden_size)
        # TE类似于Bert中的segment embedding
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """如果在pretrain阶段未定义TE,而在fine-tune阶段TE,则可通过此函数添加
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """定义输入X过embedding层的计算方法
        """

        # words_embeddings size = (b, seq_len, hidden_size)
        # 再次注意:self.word_embeddings做forward时,最终的输出结果是AllReduce的
        words_embeddings = self.word_embeddings(input_ids) 
        # position_embeddings size = (b, seq_len, hidden_size)
        position_embeddings = self.position_embeddings(position_ids)
        # embedding = WE + PE
        # embedding size = (b, seq_len, hidden_size)
        embeddings = words_embeddings + position_embeddings
        # 依需要决定是否增加TE
        if tokentype_ids is not None: 
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(
        self, destination=None, prefix='', keep_vars=False,
    ):
        """For easy load.
        在模型训练过程中及时读取当前参数,方便及时保存(做checkpoint)
        篇幅限制,这里不展示细节
        """
        ...

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.
        用于模型的重载。例如训到一半挂掉了,我们就重新初始化一个新模型,
        重载上个checkpoint保存下的权重。
        篇幅限制,这里不展示细节
        """
        ...

可以看到,分布式Embedding的实现依靠VocabParallelEmbedding类,其定义如下:

class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim, init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings # vocab_size
        self.embedding_dim = embedding_dim # hidden_state.
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.0
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        # 当前进程所在TP组进程总数
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # 根据当前进程在TP组中的序号,确定其所需维护的WE部分,沿着vocab维度对WE进行切割
        # 例如,进程id=0, 维护词表序号[0,5)范围内的数据;进程id=1,维护[5,10)
        (
            self.vocab_start_index,
            self.vocab_end_index,
        ) = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings,
            get_tensor_model_parallel_rank(),
            self.tensor_model_parallel_size,
        )
        # 计算当前进程维护的词表大小
        self.num_embeddings_per_partition = (
            self.vocab_end_index - self.vocab_start_index
        )

        # 对WE做初始化
        args = get_args() # 读取预训练参数配置
        if args.use_cpu_initialization: # CPU上做初始化
            self.weight = Parameter( # 在CPU上创建当前进程维护的那部分WE(完整的WE在_initialize_affine_weight_cpu内部生成)
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    dtype=args.params_dtype,
                    # dtype=torch.float32,
                )
            )
            # 对CPU上的WE做切割(随机种子在初始化分布式中已设定好,不用变)
            _initialize_affine_weight_cpu(
                self.weight, 
                self.num_embeddings, 
                self.embedding_dim,
                self.num_embeddings_per_partition,
                0,
                init_method, # 初始化权重的方法,例如xavier之类
            )
        else: # 在GPU上做初始化
            self.weight = Parameter( # 生成一个切割好的WE
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype,
                    # dtype=torch.float32,
                )
            )
            # 在GPU上做初始化,注意TP组内不同进程采用不同的随机种子
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=0, stride=1
            )

    def forward(self, input_):
        """定义输入X过WE的计算方法,输出结果已经过AllReduce"""
        if self.tensor_model_parallel_size > 1: # 如果使用TP
            # 如果在当前进程维护的WE上,找不到对应的单词,那么对应位置就赋0
            # 例如当前的数据的tokenid是:[2,7,1,5],当前维护的词表是[0,1,2](start_index=0, end_index = 3),
            # 则mask之后的数据为[2,0,1,0]
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | (
                input_ >= self.vocab_end_index
            )
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_

        # 输入X,过当前进程维护的部分WE的结果
        output_parallel = F.embedding(
            masked_input, # tensor containing indices into the embedding matrix
            self.weight, # 切割好的word embedding的权重
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # 当前词表不维护的部分,都设为0
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0 #

        # 将TP组各GPU上的结果做AllReduce
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output

_initialize_affine_weight_cpu和_initialize_affine_weight_gpu的完整实现已在上文“模型定义和切分”一节给出,这里不再重复。

ParallelSelfAttention

由前文的原理可知,对Q、K、V三个参数矩阵采用“列切分”,对输出线性层B采用“行切分”。这样设计的好处是:经过QKV的计算后,各进程无需通讯即可继续完成attention和线性层的计算,直到最后一步才做AllReduce,从而降低通讯成本。

列切分矩阵乘法的定义,如下:

class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A. 
        output_size: second dimension of matrix A. 
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i 
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it. 
    """
    # 该类定义了切割后的权重W,例如切割后的W1和W2都可分别视为该类的一个实例

    def __init__(
        self,
        input_size, # W的第一个维度
        output_size, # W的第二个维度
        bias=True, # 是否需要引入bias
        gather_output=True, # 决定是否要将Y1和Y2做all-gather
        init_method=init.xavier_normal_,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        params_dtype=None,
        skip_init=False,
        device=None,
    ):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size 
        self.output_size = output_size 
        self.gather_output = gather_output 
        # Divide the weight matrix along the last dimension.
        # 当前进程所在TP组的总进程数
        world_size = get_tensor_model_parallel_world_size()
        # 每块GPU上维护的输出维度大小,等于 output_size // TP组总进程数
        self.output_size_per_partition = divide(output_size, world_size) 
        self.skip_bias_add = skip_bias_add 
        self.params_dtype = params_dtype 
        self.device = device 
        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args() # 取得命令行所有的参数
        if not skip_init: 
            if args.use_cpu_initialization: # CPU上初始化
                self.weight = Parameter(  
                    torch.empty(
                        self.output_size_per_partition,
                        self.input_size,
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                self.master_weight = _initialize_affine_weight_cpu( # 
                    self.weight,
                    self.output_size,
                    self.input_size,
                    self.output_size_per_partition,
                    0,
                    init_method,
                    stride=stride,
                    return_master_weight=keep_master_weight_for_test,
                )
            else: # GPU上初始化
                self.weight = Parameter( 
                    torch.empty(
                        self.output_size_per_partition,
                        self.input_size,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                _initialize_affine_weight_gpu( 
                    self.weight, init_method, partition_dim=0, stride=stride
                )
        else:
            self.register_parameter("weight", None)

        # 对bias做处理,道理同weight
        if bias and not skip_init: 
            if args.use_cpu_initialization: # CPU上初始化
                self.bias = Parameter(
                    torch.empty(self.output_size_per_partition, 
                                dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype)
                ) 
            else: 
                self.bias = Parameter( # GPU上初始化
                    torch.empty(
                        self.output_size_per_partition,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )

            set_tensor_model_parallel_attributes(self.bias, True, 0, stride) 
            # Always initialize bias to zero. 
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        # 定义列切割中的f算子
        # 调用copy_to_tensor_model_parallel_region则新建一个_CopyToModelParallelRegion实例(见下)
        input_parallel = copy_to_tensor_model_parallel_region(input_)

        bias = self.bias if not self.skip_bias_add else None # 定义bias
        output_parallel = F.linear(input_parallel, self.weight, bias) # X * 切割好的权重
        # 决定是否要对每个进程上的输出结果做All-Gather(沿列拼接)
        if self.gather_output:
            # 定义列切割中的g算子
            # 调用gather_from_tensor_model_parallel_region则新建一个_GatherFromModelParallelRegion实例(见下)
            output = gather_from_tensor_model_parallel_region(output_parallel) # 把各GPU上的输出按照列gather起来后,作为最终输出
        else:
            output = output_parallel # 否则最终输出还是自己算的那块GPU
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

# 列切割中的f与g
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""
    # 列切割下的f算子
    # forward:copy输入
    # backward:对梯度做AllReduce

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)

class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatinate."""
    # 列切割中的g算子
    # forward:All-Gather输出
    # backward:对梯度,沿着列方向做split

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

行切分矩阵乘法的定义,如下:

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it.
    """

    def __init__(
        self,
        input_size,
        output_size,
        bias=True,
        input_is_parallel=False,
        init_method=init.xavier_normal_,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        params_dtype=None,
        skip_init=False,
        device=None,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.params_dtype = params_dtype
        self.device = device

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if not skip_init:
            if args.use_cpu_initialization:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size,
                        self.input_size_per_partition,
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight,
                    self.output_size,
                    self.input_size,
                    self.input_size_per_partition,
                    1,
                    init_method,
                    stride=stride,
                    return_master_weight=keep_master_weight_for_test,
                )
            else:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size,
                        self.input_size_per_partition,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                _initialize_affine_weight_gpu(
                    self.weight, init_method, partition_dim=1, stride=stride
                )
        else:
            self.register_parameter("weight", None)

        if bias and not skip_init:
            if args.use_cpu_initialization:
                self.bias = Parameter(
                    torch.empty(self.output_size, 
                                dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype)
                )
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias

# 行切割中的f和g算子
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chuck to the rank."""
    # 行切割中的f算子
    # forward:沿列split输入
    # backward:all-gather梯度
    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)

class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""
    # 行切割中的g算子
    # forward:AllReduce输出
    # backward:正常计算梯度,GPU间无需做任何通讯
    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

ParallelSelfAttention 类由“列切分”ColumnParallelLinear和“行切分”RowParallelLinear组合而来。

class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(
            args.hidden_size // 2 if args.compress else args.hidden_size,
            world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size // 2 if args.compress else args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)
        if hasattr(args, 'attention_upweight'):
            self.attention_upweight = args.attention_upweight
        else:
            self.attention_upweight = None
        # Strided linear layer.
        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size // 2 if args.compress else args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.key = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size // 2 if args.compress else args.hidden_size,
            gather_output=False,
            init_method=init_method)
        self.value = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.hidden_size // 2 if args.compress else args.hidden_size,
            gather_output=False,
            init_method=init_method)

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size // 2 if args.compress else args.hidden_size,
            args.hidden_size,
            input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(
            self,
            hidden_states,
            attention_mask,
            layer_past=None,
            get_key_value=False,
            prompt_length=None,
            context_length=None,
    ):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================
        ...

ParallelMLP

分布式MLP也是由ColumnParallelLinearRowParallelLinear基础组件构成。

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        scale: int = 4,
    ):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            scale * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            # skip_bias_add=True,
        )

        self.activation_func = fast_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            scale * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
            init_method=output_layer_init_method,
            # skip_bias_add=True,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)

        return output, output_bias

CrossEntropy

交叉熵的并行计算,核心类为_VocabParallelCrossEntropy。我们在前文原理部分讲过交叉熵的并行计算,其优化核心是将通讯量从 $b \times s \times v$ 降至 $b \times s$。但Megatron代码中定义的交叉熵计算方式稍微复杂一些,和我们一般理解的交叉熵也有些许差异。

class _VocabParallelCrossEntropy(torch.autograd.Function):
    """
    分布式计算Loss    
    """
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):
        # 1. logit - global max(logit)操作,主要目的是防溢出
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] # (b, s)
        torch.distributed.all_reduce( # (b, s)
            logits_max,
            op=torch.distributed.ReduceOp.MAX, # 找全局最大值
            group=get_tensor_model_parallel_group(),
        )
        # Subtract the maximum value. 
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) # 原始GPU上维护的logits减去每行最大值(防止溢出)

        # 2、根据当前进程id,取出当前进程所维护词表序号等信息
        # 函数,能够获取当前进程所维护词表的start_index和end_index
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size 
        # 这块GPU上logits最后一维的大小,等于所维护的词表的大小(v/N)
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        # 取得当前进程所在TP组中的序号
        rank = get_tensor_model_parallel_rank()
        # 取得当前进程所在TP组的总进程数
        world_size = get_tensor_model_parallel_world_size()
        # 取得当前进程所维护的词表的start_index和end_index 
        vocab_start_index, vocab_end_index = get_vocab_range( 
            partition_vocab_size, rank, world_size
        )

        # 3. 基于真值,取出每个token在真值位置上的logit(即和真值的相似度)
        # Create a mask of valid vocab ids (1 means it needs to be masked)
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index) # target = (b, s)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For Simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) # (b*s, v/N)
        masked_target_1d = masked_target.view(-1) # (b*s)
        arange_1d = torch.arange( # [b*s]
            start=0, end=logits_2d.size()[0], device=logits_2d.device
        )
        # logits_2d[arange_1d, masked_target_1d]: 
        # tensor的切片操作。arange_1d:取出所有的行。masked_target_1d:取出logit
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] # (b*s)
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target) # (b, s)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce( # allreduce之后得到的logit矩阵为(b, s),每一个位置表示对应真值位置的预测logit
            predicted_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = vocab_parallel_logits # (b, s, v/N)
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1) # (b, s)
        torch.distributed.all_reduce(
            sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        # 4. 计算Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits # (b, s)

        # Store softmax, target-mask and masked-target for backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):

        # Retreive tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None
