Shortcuts

TorchRec简介

Created On: Oct 02, 2024 | Last Updated: Oct 10, 2024 | Last Verified: Oct 02, 2024

TorchRec 是一个 PyTorch 库,专为使用嵌入构建可扩展且高效的推荐系统而设计。此教程引导您完成安装过程,介绍嵌入的概念,并突出其在推荐系统中的重要性。它提供了使用 PyTorch 和 TorchRec 实现嵌入的实际演示,重点是通过分布式训练和高级优化处理大型嵌入表。

What you will learn
  • 嵌入的基础知识及其在推荐系统中的作用

  • 如何设置 TorchRec 以在 PyTorch 环境中管理和实现嵌入

  • 探索将大型嵌入表分布到多个 GPU 上的高级技术

Prerequisites
  • PyTorch v2.5 或更高版本支持 CUDA 11.8 或更高版本

  • Python 3.9 或更高版本

  • FBGEMM

安装依赖项

在 Google Colab 或其他环境中运行此教程之前,请安装以下依赖项:

!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121

备注

如果您在 Google Colab 中运行此教程,请确保切换到 GPU 运行时类型。有关更多信息,请参见 启用 CUDA

嵌入

在构建推荐系统时,类别特征通常具有较大的基数,比如帖子、用户、广告等。

为了表示这些实体并建模这些关系,使用了**嵌入**。在机器学习中,嵌入是高维空间中代表复杂数据(如单词、图像或用户)意义的实数向量

嵌入在推荐系统中

现在您可能会问,这些嵌入是如何生成的呢?嵌入表示为**嵌入表**中的单独行,也称为嵌入权重。原因是嵌入或嵌入表权重通过梯度下降像模型的所有其他权重一样进行训练!

嵌入表只是用于存储嵌入的大型矩阵,具有两个维度 (B, N),其中:

  • B 是表存储的嵌入数

  • N 是每个嵌入的维度数 (N 维嵌入)。

嵌入表的输入代表嵌入查找,用于获取特定索引或行的嵌入。在许多大型系统中使用的推荐系统中,唯一 ID 不仅用于特定用户,还用于帖子和广告等实体,作为对各自嵌入表的查找索引!

嵌入在推荐系统中通过以下过程进行训练:

  • 输入/查找索引作为唯一 ID 被馈入模型。ID 被哈希到嵌入表的总大小,以防止出现 ID > 行数的问题。

  • 然后检索嵌入并进行**池化,如求和或求平均值**。这是必要的,因为每个示例可以有可变数量的嵌入,而模型期望一致的形状。

  • 嵌入与模型的其余部分一起生成预测结果,例如广告的 点击率 (CTR)

  • 通过预测结果和示例的标签计算损失,并通过梯度下降和反向传播更新模型的所有权重,包括与示例相关的嵌入权重。

这些嵌入对于表示类别特征(如用户、帖子和广告)至关重要,以便捕捉关系并提供良好的推荐。深度学习推荐模型 (DLRM) 的论文进一步讨论了在推荐系统中使用嵌入表的技术细节。

本教程介绍了嵌入的概念,展示了 TorchRec 特定模块和数据类型,并描述了 TorchRec 的分布式训练是如何工作的。

import torch

PyTorch 中的嵌入

在 PyTorch 中,我们有以下类型的嵌入:

  • torch.nn.Embedding: 一个嵌入表,前向遍历将直接返回嵌入本身。

  • torch.nn.EmbeddingBag: 一个嵌入表,前向遍历将返回嵌入,然后进行池化,例如求和或求平均值,也被称为**池化嵌入**。

在本部分中,我们将简要介绍如何通过传递索引到表中来执行嵌入查找。

num_embeddings, embedding_dim = 10, 4

# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)

# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
    num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
    num_embeddings, embedding_dim, _weight=weights
)

# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)

embeddings = embedding_collection(ids)

# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)

# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)

print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)

# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))

恭喜!现在您已经对如何使用嵌入表有了基本的了解——这是现代推荐系统的基石之一!这些表表示实体及其关系。例如,给定用户与他们喜欢的页面和帖子的关系。

TorchRec 特性概述

在上面的部分中,我们已经学习了如何使用嵌入表,这是现代推荐系统的基础之一!这些表表示实体和关系,例如用户、页面、帖子等。鉴于这些实体不断增加,通常会应用 哈希 函数以确保 ID 在某个嵌入表的范围内。然而,为了表示大量的实体并减少哈希冲突,这些表可能变得非常庞大(比如广告的数量)。事实上,这些表可能大到即使有 80G 的内存也无法适配到 1 个 GPU 上。

为了训练具有庞大嵌入表的模型,需要在 GPU 之间分片这些表,这带来了并行和优化方面的一系列新问题和机遇。幸运的是,我们拥有 TorchRec 库,它汇总并解决了这些问题。TorchRec 作为一个**为大规模分布式嵌入提供原语的库**。

接下来,我们将探索 TorchRec 库的主要特性。我们将从 torch.nn.Embedding 开始,并扩展到自定义 TorchRec 模块,探索分布式训练环境并生成嵌入的分片计划,查看固有的 TorchRec 优化,并将模型扩展为准备在 C++ 中进行推理。以下是本部分的简要概述:

  • TorchRec 模块和数据类型

  • 分布式训练、分片和优化

  • 推断

让我们从导入 TorchRec 开始:

import torchrec

本部分介绍了 TorchRec 模块和数据类型,其中包括 EmbeddingCollectionEmbeddingBagCollectionJaggedTensorKeyedJaggedTensorKeyedTensor 等实体。

EmbeddingBagEmbeddingBagCollection

我们已经探索了 torch.nn.Embeddingtorch.nn.EmbeddingBag。TorchRec 通过创建嵌入集合扩展了这些模块,也就是可以包含多个嵌入表的模块,如 EmbeddingCollectionEmbeddingBagCollection。我们将使用 EmbeddingBagCollection 来表示一组嵌入包。

在下面的示例代码中,我们创建了一个具有两个嵌入包的 EmbeddingBagCollection (EBC),一个表示 产品,另一个表示 用户。每个表 product_tableuser_table 都由大小为 4096 的 64 维嵌入表示。

ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        )
    ]
)
print(ebc.embedding_bags)

让我们检查一下 EmbeddingBagCollection 的前向方法以及模块的输入和输出:

import inspect

# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))

TorchRec 输入/输出数据类型

TorchRec 针对其模块的输入和输出具有独特的数据类型:JaggedTensorKeyedJaggedTensorKeyedTensor。您可能会问,为什么要创建新的数据类型来表示稀疏特征?要回答这个问题,我们必须了解稀疏特征在代码中的表示方式。

稀疏特征也称为 id_list_featureid_score_list_feature,它们是将用作嵌入表索引的 IDs,以检索该 ID 的嵌入。举个简单的例子,想象一个稀疏特征是用户互动过的广告。输入本身将是用户与之互动过的一组广告 ID,检索到的嵌入将是这些广告的语义表示。在代码中表示这些特征的棘手部分在于每个输入示例中 IDs 的数量是可变的。一个用户某天可能只互动了一个广告,而第二天则可能互动了三个。

一个简单的表示如下所示,其中我们有一个 lengths 张量表示批量示例中有多少个索引,以及一个 values 张量包含这些索引本身。

# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])

# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])

接下来让我们来看看偏移量以及每个批次中包含的内容

# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)

print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
    "Second Batch: ",
    id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)

from torchrec import JaggedTensor

# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)

# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())

# Convert to list of values
print("List of Values: ", jt.to_dense())

# ``__str__`` representation
print(jt)

from torchrec import KeyedJaggedTensor

# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())

# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())

# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())

# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())

# ``KeyedJaggedTensor`` string representation
print(kjt)

# Q2: What are the offsets for the ``KeyedJaggedTensor``?

# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result

# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())

# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)

# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)

恭喜!您现在了解了 TorchRec 模块和数据类型。为自己鼓掌,为取得的进展感到骄傲。接下来,我们将学习分布式训练和分片。

分布式训练和分片

现在我们对 TorchRec 模块和数据类型有了了解,是时候进入下一阶段了。

请记住,TorchRec 的主要目的是为分布式嵌入提供原语。到目前为止,我们只在单个设备上处理嵌入表。由于嵌入表很小,这一点是可行的,但在生产环境中通常并非如此。嵌入表通常非常庞大,其中一个表无法适配到单个 GPU 上,因此需要多设备和分布式环境。

在本部分中,我们将探讨设置分布式环境的过程,了解实际生产训练是如何完成的,并探索嵌入表的分片,这些都是在 TorchRec 中完成的。

本部分也仅使用 1 个 GPU,但它将在分布式环境中对待。这仅是训练的限制,因为训练每个 GPU 都需要一个进程。而推理则没有这一要求。

在下面的示例代码中,我们设置了 PyTorch 分布式环境。

警告

如果您正在 Google Colab 中运行此代码,您只能调用此代码块一次,再次调用将导致错误,因为进程组只能初始化一次。

import os

import torch.distributed as dist

# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")

print(f"Distributed environment initialized: {dist}")

分布式嵌入

我们已经处理了主要的 TorchRec 模块:EmbeddingBagCollection。我们检查了它如何工作以及数据在 TorchRec 中的表示方式。然而,我们还没有探索 TorchRec 的主要部分之一,那就是 分布式嵌入

如今,GPU 是最流行的 ML 工作负载选择,因为它们能够执行比 CPU 多得多的浮点运算 (FLOPs)。然而,GPU 存在快速内存(HBM,相当于 CPU 的 RAM)稀缺的限制,通常只有几十 GB。

一个推荐系统模型可以包含远远超过 1 个 GPU 的内存限制的嵌入表,因此需要将嵌入表分布到多个 GPU 上,也就是所谓的 模型并行。另一方面,**数据并行**是指整个模型在每个 GPU 上都被复制,每个 GPU 处理一批不同的数据进行训练,在反向传播过程中同步梯度。

计算需求较低但内存需求较高的模型部分(如嵌入)通过模型并行分布,而计算需求较高但内存需求较低的模型部分(如密集层、MLP 等)通过数据并行分布。

分片

为了分布一个嵌入表,我们将嵌入表拆分成多个部分并将这些部分放置到不同的设备上,也就是所谓的“分片”。

分片嵌入表的方式有很多,最常见的方式有:

  • 表级分片:表完全放置在一个设备上

  • 列级分片:嵌入表的列进行分片

  • 行级分片:嵌入表的行进行分片

分片模块

虽然这些看起来需要解决和实现的任务很多,但您走运了。**TorchRec 提供了所有用于简单分布式训练和推理的原语!**事实上,TorchRec 模块在分布式环境中使用任何 TorchRec 模块时都有两个对应的类:

  • 模块分片器:此类提供一个 shard API,它处理 TorchRec 模块的分片,生成一个分片模块。 * 对于 EmbeddingBagCollection,分片器是 EmbeddingBagCollectionSharder

  • 分片模块:此类是 TorchRec 模块的分片变体。其输入/输出与普通 TorchRec 模块相同,但经过大量优化并能在分布式环境中工作。 * 对于 EmbeddingBagCollection,分片变体是 ShardedEmbeddingBagCollection

每个 TorchRec 模块都有未分片和分片变体。

  • 未分片版本用于原型设计和实验。

  • 分片版本用于分布式环境的分布式训练和推理。

TorchRec 模块的分片版本,例如 EmbeddingBagCollection,将处理模型并行性所需的一切,例如在 GPU 之间通信以将嵌入分布到正确的 GPU 上。

回顾我们的 EmbeddingBagCollection 模块

ebc

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()

# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"

print(f"Process Group: {pg}")

规划器

在展示分片如何工作之前,我们必须了解 规划器,它帮助我们确定最佳分片配置。

给定数量的嵌入表和 GPU 排列,有许多不同的分片配置是可能的。例如,给定两个嵌入表和两个 GPU,您可以:

  • 将一个表放置在每个 GPU 上

  • 将两个表都放到一个 GPU 上而另一个 GPU 不放置表

  • 将某些行和列分片放到每个 GPU 上

鉴于所有这些可能性,我们通常需要一个性能最佳的分片配置。

这就是规划器的作用。规划器可以根据嵌入表数量和 GPU 数量确定最佳配置。事实证明,手动完成这项任务非常困难,工程师需要考虑许多因素来确保优化的分片计划。幸运的是,当使用规划器时 TorchRec 提供了一个自动规划器。

TorchRec 规划器能:

  • 评估硬件的内存限制

  • 根据嵌入查找作为内存提取估算计算量

  • 处理数据特定因素

  • 考虑其他硬件特性,如带宽,以生成优化的分片计划

为了考虑到所有这些变量,TorchRec的规划器可以接受嵌入表的不同数据量、约束条件、硬件信息和拓扑结构,以帮助生成模型的最佳分片计划,这通常会在不同的栈之间提供。

要了解更多关于分片的信息,请参阅我们的分片教程。

# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=1,
        compute_device="cuda",
    )
)

# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)

print(f"Sharding Plan generated: {plan}")

规划器结果

正如您在上面看到的,运行规划器时会产生大量输出。我们可以看到许多统计数据的计算,以及嵌入表的最终放置位置。

运行规划器的结果是一个静态计划,可以重复用于分片!这允许生产模型的分片计划是静态的,而不是每次都重新生成新的分片计划。下面,我们使用分片计划最终生成了我们的``ShardedEmbeddingBagCollection``。

# The static plan that was generated
plan

env = ShardingEnv.from_process_group(pg)

# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

print(f"Sharded EBC Module: {sharded_ebc}")

使用``LazyAwaitable``进行GPU训练

请记住,TorchRec是一个针对分布式嵌入高度优化的库。TorchRec引入的一个概念是`LazyAwaitable`,以实现GPU训练的更高性能。您将在各种分片的TorchRec模块的输出中看到``LazyAwaitable``类型。``LazyAwaitable``类型的作用就是尽可能延迟计算某些结果,它通过扮演类似异步类型来实现这一点。

from typing import List

from torchrec.distributed.types import LazyAwaitable


# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        return torch.ones(self._size)


awaitable = ExampleAwaitable([3, 2])
awaitable.wait()

kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)

kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))

print(kt.keys())

print(kt.values().shape)

# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)

分片的TorchRec模块结构

我们已经成功分片了给定分片计划的``EmbeddingBagCollection``!分片模块具有TorchRec的通用API,抽象了多GPU间分布式通信和计算。事实上,这些API在训练和推断过程中进行了高度优化。以下是TorchRec为分布式训练和推断提供的三个通用API:

  • input_dist:处理从GPU到GPU的输入分发。

  • lookups:使用FBGEMM TBE以优化的、批处理的方式执行实际的嵌入查找(稍后详细介绍)。

  • output_dist:处理从GPU到GPU的输出分发。

输入和输出的分发通过NCCL Collectives(例如All-to-Alls)实现,这样所有GPU之间可发送和接收数据。TorchRec与PyTorch分布式进行接口连接,并为终端用户提供简洁抽象,去除了关注底层细节的需求。

反向传播过程会以相反的顺序执行所有这些收集操作,从而分发梯度。input_dist``lookup``和``output_dist``全部依赖于分片方案。由于我们是以表级方式分片的,这些API是通过TwPooledEmbeddingSharding模块构造的。

sharded_ebc

# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists

# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists

优化嵌入查找

在为嵌入表集合执行查找时,一个简单的解决方案是迭代所有的``nn.EmbeddingBags``并逐表进行查找。这正是标准未分片的``EmbeddingBagCollection``所做的。然而,虽然此解决方案简单,但非常慢。

FBGEMM是一个提供GPU操作符(即内核)的库,这些操作符经过高度优化。其中之一被称为表批处理嵌入(TBE),它提供了两个主要优化:

  • 表批处理,允许通过一次内核调用查找多个嵌入。

  • 优化器融合,允许模块根据PyTorch的优化器和参数更新自身。

ShardedEmbeddingBagCollection``使用FBGEMM TBE作为查找方式,而非传统的``nn.EmbeddingBags,以优化嵌入查找。

sharded_ebc._lookups

DistributedModelParallel

我们已经探索了一个``EmbeddingBagCollection``的分片!我们能够使用``EmbeddingBagCollectionSharder``和未分片的``EmbeddingBagCollection``生成``ShardedEmbeddingBagCollection``模块。这种工作流是可以接受的,但通常在实现模型并行时,使用DistributedModelParallel(DMP)作为标准接口。当用DMP包装您的模型(在我们的例子中是``ebc``),将发生以下过程:

  1. 决定如何分片模型。DMP会收集可用的分片器,并制定嵌入表(例如,EmbeddingBagCollection)分片的最佳计划。

  2. 实际上进行模型分片。这包括在适当的设备上为每个嵌入表分配内存。

DMP接收我们刚刚尝试过的所有内容,例如静态分片计划、分片器列表等。然而,它还具有一些很好的默认设置,可以无缝地分片TorchRec模型。在这个实例中,由于我们有两个嵌入表和一个GPU,TorchRec会将两者都放在单个GPU上。

ebc

model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))

out = model(kjt)
out.wait()

model

分片最佳实践

目前,我们的配置仅在一个GPU(或rank)上进行分片,这很简单:只需将所有表放在一个GPU的内存中。然而,在实际的生产场景中,嵌入表通常**分片到数百个GPU上**,使用不同的分片方法,如表级、行级和列级。确定适当的分片配置(防止内存不足问题)并在内存和计算方面保持平衡以优化性能,这非常重要。

添加优化器

请记住,TorchRec模块针对大规模分布式训练进行了超强优化。一个重要的优化与优化器相关。

TorchRec模块提供了无缝的API,将反向传播和优化步骤合并到训练中,这显著提高了性能,减少了内存使用,并增加了为不同的模型参数分配不同优化器的细粒度控制。

优化器类

TorchRec使用``CombinedOptimizer``,它包含一组``KeyedOptimizers``。CombinedOptimizer``实际上可以轻松地处理模型中各子组的多个优化器。``KeyedOptimizer``扩展了``torch.optim.Optimizer,通过参数字典初始化并公开参数。EmbeddingBagCollection``中的每个``TBE``模块都会有它自己的``KeyedOptimizer,然后这些优化器被组合成一个``CombinedOptimizer``。

TorchRec中的融合优化器

在使用``DistributedModelParallel``时,优化器是融合的,这意味着优化器更新是在反向传播中完成的。这是TorchRec和FBGEMM中的一个优化,优化器的嵌入梯度不会被物化,而是直接应用于参数。这带来了显著的内存节省,因为嵌入梯度通常与参数本身的大小相当。

但是,您可以选择让优化器变为``dense``,这不会应用此优化,并让您检查嵌入梯度或对其进行计算。对于这种情况,``dense``优化器将是您在PyTorch中进行模型训练的常规优化器。

一旦通过``DistributedModelParallel``创建优化器,您仍需要为与TorchRec嵌入模块无关的其他参数管理一个优化器。要找到这些其他参数,请使用``in_backward_optimizer_filter(model.named_parameters())``。像正常的Torch优化器一样为这些参数应用优化器,并将这部分优化器与``model.fused_optimizer``合并到一个``CombinedOptimizer``中,您可以在训练循环中用它来进行``zero_grad``和``step``操作。

向``EmbeddingBagCollection``添加优化器

我们将通过两个方式这样做,它们是等效的,但可以根据您的偏好提供选择:

  1. 通过分片器中的``fused_params``的优化器kwargs。

  2. 通过``apply_optimizer_in_backward``,将优化器参数转换为``fused_params``以传递给``EmbeddingBagCollection``或``EmbeddingCollection``中的``TBE``。

# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
from fbgemm_gpu.split_embedding_configs import EmbOptimType


# We initialize the sharder with
fused_params = {
    "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
    "learning_rate": 0.02,
    "eps": 0.002,
}

# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))

# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")

print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")

from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
import copy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it

# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}

for name, param in ebc_apply_opt.named_parameters():
    print(f"{name=}")
    apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)

sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))

# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))

# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())

# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")

loss.backward()

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")

推断

现在我们能够训练分布式嵌入了,那么如何将训练好的模型优化用于推断呢?推断通常对**模型性能和大小**非常敏感。在Python环境中仅运行训练好的模型是非常低效的。训练环境和推断环境之间有两个关键区别:

  • 量化:推断模型通常会被量化,将模型参数的精度降低以获得更低的预测延迟和更小的模型大小。例如,将训练模型中的FP32(每权重4字节)转换为INT8(每权重1字节)。这是必要的,因为嵌入表的规模非常巨大,我们希望在推断中使用尽可能少的设备以最小化延迟。

  • C++环境:推断延迟非常重要,为了确保足够的性能,模型通常运行在C++环境中,或者在某些没有Python运行时的设备上。

TorchRec提供了将TorchRec模型转换为推断准备状态的基本措施:

  • 用于量化模型的API,自动使用FBGEMM TBE进行优化

  • 为分布式推断分片嵌入

  • 将模型编译为TorchScript(兼容C++)

在本节中,我们将介绍以下工作流程:

  • 量化模型

  • 对量化模型进行分片

  • 将分片量化的模型编译为TorchScript

ebc

class InferenceModule(torch.nn.Module):
    def __init__(self, ebc: torchrec.EmbeddingBagCollection):
        super().__init__()
        self.ebc_ = ebc

    def forward(self, kjt: KeyedJaggedTensor):
        return self.ebc_(kjt)

module = InferenceModule(ebc)
for name, param in module.named_parameters():
    # Here, the parameters should still be FP32, as we are using a standard EBC
    # FP32 is default, regularly used for training
    print(name, param.shape, param.dtype)

量化

正如您在上面看到的,普通EBC包含嵌入表权重作为FP32精度(每权重32位)。在这里,我们将使用TorchRec推断库,将模型的嵌入权重量化为INT8。

from torch import quantization as quant
from torchrec.modules.embedding_configs import QuantConfig
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)


quant_dtype = torch.int8


qconfig = QuantConfig(
    # dtype of the result of the embedding lookup, post activation
    # torch.float generally for compatibility with rest of the model
    # as rest of the model here usually isn't quantized
    activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
    # quantized type for embedding weights, aka parameters to actually quantize
    weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
    # Map of module type to qconfig
    torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
    # Map of module type to quantized module type
    torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}


module = InferenceModule(ebc)

# Quantize the module
qebc = quant.quantize_dynamic(
    module,
    qconfig_spec=qconfig_spec,
    mapping=mapping,
    inplace=False,
)


print(f"Quantized EBC: {qebc}")

kjt = kjt.to("cpu")

qebc(kjt)

# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
    # The shapes of the tables should be the same but the dtype should be int8 now
    # post quantization
    print(name, buffer.shape, buffer.dtype)

分片

这里我们对TorchRec量化模型进行分片。这样可以确保我们使用通过FBGEMM TBE的高性能模块。为了与训练保持一致,这里我们使用一个设备(1 TBE)。

from torchrec import distributed as trec_dist
from torchrec.distributed.shard import _shard_modules


sharded_qebc = _shard_modules(
    module=qebc,
    device=torch.device("cpu"),
    env=trec_dist.ShardingEnv.from_local(
        1,
        0,
    ),
)


print(f"Sharded Quantized EBC: {sharded_qebc}")

sharded_qebc(kjt)

编译

现在我们有了优化的急切TorchRec推理模型。下一步是确保此模型可以在C++中加载,因为目前它只能在Python运行时中运行。

Meta推荐的编译方法有两种:`torch.fx tracing <https://pytorch.org/docs/stable/fx.html>`__(生成模型中间表示)并将结果转换为TorchScript,TorchScript与C++兼容。

from torchrec.fx import Tracer


tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])

graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)

print("Graph Module Created!")

print(gm.code)

scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")

print(scripted_gm.code)

结论

在本教程中,您从训练一个分布式RecSys模型到使其准备推理。此外,TorchRec仓库 有完整示例,说明如何将TorchRec TorchScript模型加载到C++中进行推理。

了解更多信息,请参阅我们的 dlrm 示例,其中包括使用 Deep Learning Recommendation Model for Personalization and Recommendation Systems 描述的方法在Criteo 1TB数据集上进行多节点训练。

**脚本的总运行时间:**(0分钟0.000秒)

通过Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源