TorchRec 简介
TorchRec 是一个专为使用嵌入构建可扩展且高效的推荐系统而设计的 PyTorch 库。本教程将引导您完成安装过程,介绍嵌入的概念,并强调它们在推荐系统中的重要性。它提供了使用 PyTorch 和 TorchRec 实现嵌入的实用演示,重点是通过分布式训练和高级优化技术来处理大型嵌入表。
你将学到什么
-
嵌入的基础知识及其在推荐系统中的作用
-
如何设置 TorchRec 以在 PyTorch 环境中管理和实现嵌入
-
探索将大型嵌入表分布在多个 GPU 上的高级技术
先决条件
-
PyTorch v2.5 或更高版本,且 CUDA 11.8 或更高版本
-
Python 3.9 或更高版本
安装依赖
在 Google Colab 或其他环境中运行本教程之前,请先安装以下依赖项:
!pip3install--pretorch--index-urlhttps://download.pytorch.org/whl/cu121-U
!pip3installfbgemm_gpu--index-urlhttps://download.pytorch.org/whl/cu121
!pip3installtorchmetrics==1.0.3
!pip3installtorchrec--index-urlhttps://download.pytorch.org/whl/cu121
如果您在 Google Colab 中运行此操作,请确保切换到 GPU 运行时类型。有关更多信息,请参阅 启用 CUDA
嵌入
在构建推荐系统时,分类特征通常具有巨大的基数,例如帖子、用户、广告等。
为了表示这些实体并建模这些关系,通常会使用嵌入。在机器学习中,嵌入是高维空间中的实数向量,用于表示复杂数据(如单词、图像或用户)中的含义。
推荐系统中的嵌入
现在您可能会想,这些嵌入最初是如何生成的呢?其实,嵌入是以 嵌入表 中的单独行来表示的,也称为嵌入权重。这是因为嵌入或嵌入表权重与模型的其他所有权重一样,都是通过梯度下降法进行训练的!
嵌入表本质上是一个用于存储嵌入的大型矩阵,具有两个维度(B, N),其中:
-
B 是表中存储的嵌入数量
-
N 是每个嵌入的维度数量(N 维嵌入)。
嵌入表的输入表示嵌入查找,用于检索特定索引或行的嵌入。在推荐系统中,例如许多大型系统中使用的推荐系统,唯一ID不仅用于特定用户,还用于跨实体(如帖子和广告),作为相应嵌入表的查找索引!
在推荐系统中,嵌入通过以下过程进行训练:
-
输入/查找索引作为唯一 ID 输入到模型中。ID 会被哈希到嵌入表的总大小,以防止 ID 大于行数时出现问题。
-
然后检索嵌入并进行池化操作,例如对嵌入求和或取平均值。这是必要的,因为每个样本的嵌入数量可能不同,而模型期望输入的形状一致。
-
嵌入与模型的其他部分结合使用以生成预测,例如广告的点击率(CTR)。
-
根据预测值和样本的标签计算损失,并通过梯度下降和反向传播更新模型的所有权重,包括与该样本相关的嵌入权重。
这些嵌入对于表示分类特征(如用户、帖子和广告)至关重要,以便捕捉关系并做出良好的推荐。深度学习推荐模型(DLRM)论文详细讨论了在推荐系统中使用嵌入表的技术细节。
本教程介绍了嵌入的概念,展示了 TorchRec 特定的模块和数据类型,并描述了 TorchRec 如何实现分布式训练。
importtorch
PyTorch 中的嵌入
在 PyTorch 中,我们有以下几种类型的嵌入:
-
torch.nn.Embedding
: 一个嵌入表,前向传播时直接返回嵌入本身。 -
torch.nn.EmbeddingBag
: 嵌入表,前向传播时返回嵌入后对其进行池化操作,例如求和或求平均,也称为池化嵌入。
在本节中,我们将简要介绍如何通过将索引传入表中来执行嵌入查找。
num_embeddings, embedding_dim = 10, 4
# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)
# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
num_embeddings, embedding_dim, _weight=weights
)
# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)
# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)
embeddings = embedding_collection(ids)
# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)
# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)
# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
Weights: tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]])
Embedding Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Embedding Bag Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Input row IDS: tensor([[1, 3]])
Embedding Collection Results:
tensor([[[0.3904, 0.6009, 0.2566, 0.7936],
[0.8694, 0.5677, 0.7411, 0.4294]]], grad_fn=<EmbeddingBackward0>)
Shape: torch.Size([1, 2, 4])
Embedding Bag Collection Results:
tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<EmbeddingBagBackward0>)
Shape: torch.Size([1, 4])
Mean: tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<MeanBackward1>)
恭喜!现在您已经对如何使用嵌入表有了基本的了解——这是现代推荐系统的基石之一!这些表代表实体及其关系。例如,某个用户与他们点赞的页面和帖子之间的关系。
TorchRec 功能概述
在上面的章节中,我们已经学习了如何使用嵌入表,这是现代推荐系统的基石之一!这些表表示实体及其关系,例如用户、页面、帖子等。由于这些实体的数量不断增加,通常会应用哈希函数以确保ID在某个嵌入表的范围内。然而,为了表示大量的实体并减少哈希冲突,这些表可能会变得非常庞大(例如,考虑广告的数量)。事实上,这些表可能会变得如此庞大,以至于即使有80G的内存,也无法在单个GPU上容纳它们。
为了训练包含大规模嵌入表的模型,需要将这些表分散到多个 GPU 上进行分片,这在并行性和优化方面引入了一系列全新的问题和机遇。幸运的是,我们有 TorchRec 库,它已经遇到、整合并解决了许多这些问题。TorchRec 是一个提供大规模分布式嵌入原语的库。
接下来,我们将探索 TorchRec 库的主要功能。我们将从 torch.nn.Embedding
开始,并扩展到自定义的 TorchRec 模块,探索为嵌入生成分片计划的分布式训练环境,了解 TorchRec 的内置优化,并将模型扩展为适用于 C++ 推理。以下是本节内容的简要概述:
-
TorchRec 模块与数据类型
-
分布式训练、分片与优化
-
推理
让我们从导入 TorchRec 开始:
importtorchrec
本节将介绍 TorchRec 模块和数据类型,包括 EmbeddingCollection
、EmbeddingBagCollection
、JaggedTensor
、KeyedJaggedTensor
、KeyedTensor
等实体。
从 EmbeddingBag
到 EmbeddingBagCollection
我们已经探讨过 torch.nn.Embedding
和 torch.nn.EmbeddingBag
。TorchRec 通过创建嵌入集合扩展了这些模块,换句话说,这些模块可以包含多个嵌入表,具体通过 EmbeddingCollection
和 EmbeddingBagCollection
实现。我们将使用 EmbeddingBagCollection
来表示一组嵌入包。
在下面的示例代码中,我们创建了一个包含两个嵌入包的 EmbeddingBagCollection
(EBC),其中一个表示产品,另一个表示用户。每个表,product_table
和 user_table
,都由一个维度为64、大小为4096的嵌入表示。
ebc = torchrec.EmbeddingBagCollection(
device="cpu",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["product"],
pooling=torchrec.PoolingType.SUM,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["user"],
pooling=torchrec.PoolingType.SUM,
)
]
)
print(ebc.embedding_bags)
ModuleDict(
(product_table): EmbeddingBag(4096, 64, mode='sum')
(user_table): EmbeddingBag(4096, 64, mode='sum')
)
让我们检查一下 EmbeddingBagCollection
的 forward
方法以及该模块的输入和输出:
importinspect
# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
Args:
features (KeyedJaggedTensor): Input KJT
Returns:
KeyedTensor
"""
flat_feature_names: List[str] = []
for names in self._feature_names:
flat_feature_names.extend(names)
inverse_indices = reorder_inverse_indices(
inverse_indices=features.inverse_indices_or_none(),
feature_names=flat_feature_names,
)
pooled_embeddings: List[torch.Tensor] = []
feature_dict = features.to_dict()
for i, embedding_bag in enumerate(self.embedding_bags.values()):
for feature_name in self._feature_names[i]:
f = feature_dict[feature_name]
res = embedding_bag(
input=f.values(),
offsets=f.offsets(),
per_sample_weights=f.weights() if self._is_weighted else None,
).float()
pooled_embeddings.append(res)
return KeyedTensor(
keys=self._embedding_names,
values=process_pooled_embeddings(
pooled_embeddings=pooled_embeddings,
inverse_indices=inverse_indices,
),
length_per_key=self._lengths_per_embedding,
)
TorchRec 输入/输出数据类型
TorchRec 为其模块的输入和输出定义了不同的数据类型:JaggedTensor
、KeyedJaggedTensor
和 KeyedTensor
。现在你可能会问,为什么要创建新的数据类型来表示稀疏特征?要回答这个问题,我们必须理解稀疏特征在代码中是如何表示的。
稀疏特征也被称为 id_list_feature
和 id_score_list_feature
,它们是将用作嵌入表索引的 ID,以检索该 ID 的嵌入。举一个非常简单的例子,假设一个稀疏特征是用户与之互动的广告。输入本身将是一组用户与之互动的广告 ID,而检索到的嵌入将是这些广告的语义表示。在代码中表示这些特征的棘手之处在于,在每个输入示例中,ID 的数量是可变的。某一天用户可能只与一个广告互动,而第二天他们可能与三个广告互动。
一个简单的表示如下所示,其中我们有一个 lengths
张量表示批次中每个示例的索引数量,以及一个包含索引本身的 values
张量。
# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])
接下来,让我们看看偏移量以及每个批次中包含的内容。
# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
"Second Batch: ",
id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)
fromtorchrecimport JaggedTensor
# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())
# Convert to list of values
print("List of Values: ", jt.to_dense())
# ``__str__`` representation
print(jt)
fromtorchrecimport KeyedJaggedTensor
# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())
# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())
# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())
# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())
# ``KeyedJaggedTensor`` string representation
print(kjt)
# Q2: What are the offsets for the ``KeyedJaggedTensor``?
# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result
# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())
# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)
# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
Offsets: tensor([1, 3])
First Batch: tensor([5])
Second Batch: tensor([7, 1])
Offsets: tensor([0, 1, 3])
List of Values: [tensor([5]), tensor([7, 1])]
JaggedTensor({
[[5], [7, 1]]
})
Keys: ['product', 'user']
Lengths: tensor([3, 1, 2, 2])
Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])
to_dict: {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f9a1b01f310>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f9a1b01e860>}
KeyedJaggedTensor({
"product": [[1, 2, 1], [5]],
"user": [[2, 3], [4, 1]]
})
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
恭喜!您现在已经了解了 TorchRec 模块和数据类型。为您的坚持点赞。接下来,我们将学习分布式训练和分片技术。
分布式训练与分片
既然我们已经了解了 TorchRec 模块和数据类型,是时候将其提升到一个新的水平了。
请记住,TorchRec 的主要目的是为分布式嵌入提供基础组件。到目前为止,我们只在单设备上处理嵌入表。这在嵌入表较小的情况下是可行的,但在生产环境中通常并非如此。嵌入表往往会变得非常庞大,以至于一个表无法容纳在单个 GPU 上,这就需要使用多设备和分布式环境。
在本节中,我们将探讨如何设置分布式环境,如何进行实际的生产训练,以及如何使用 TorchRec 对嵌入表进行分片。
本节仅使用 1 个 GPU,但会以分布式方式进行处理。这只是训练的限制,因为训练需要为每个 GPU 分配一个进程。推理则没有这个要求。
在下面的示例代码中,我们设置了 PyTorch 的分布式环境。
如果您在 Google Colab 中运行此代码,只能调用此单元一次,再次调用会导致错误,因为进程组只能初始化一次。
importos
importtorch.distributedasdist
# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"
# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")
print(f"Distributed environment initialized: {dist}")
Distributed environment initialized: <module 'torch.distributed' from '/usr/local/lib/python3.10/dist-packages/torch/distributed/__init__.py'>
分布式嵌入
我们已经使用过 TorchRec 的主要模块:EmbeddingBagCollection
。我们探讨了它的工作原理以及数据在 TorchRec 中的表示方式。然而,我们尚未深入探索 TorchRec 的一个重要部分,即分布式嵌入。
目前,GPU 是机器学习工作负载中最受欢迎的选择,因为它们能够执行比 CPU 多几个数量级的浮点运算/秒 (FLOPs)。然而,GPU 存在快速内存(HBM,类似于 CPU 的 RAM)稀缺的限制,通常只有几十 GB。
推荐系统模型可能包含远超过单个 GPU 内存限制的嵌入表,因此需要将嵌入表分布在多个 GPU 上,这也被称为模型并行。另一方面,数据并行则是将整个模型复制到每个 GPU 上,每个 GPU 处理不同的数据批次进行训练,并在反向传播时同步梯度。
模型中需要较少计算但更多内存的部分(如嵌入)通过模型并行进行分布,而需要更多计算但较少内存的部分(如全连接层、多层感知机等)则通过数据并行进行分布。
分片
为了分发嵌入表,我们将嵌入表分割成多个部分,并将这些部分放置在不同的设备上,这个过程也称为“分片”。
分片嵌入表有多种方法。最常见的方式包括:
-
表级别:整个表被放置在一个设备上
-
列级别:嵌入表的列被分片
-
行级别:嵌入表的行被分片
分片模块
虽然这一切看起来需要处理和实施的内容很多,但您很幸运。TorchRec 提供了所有用于简化分布式训练和推理的原语! 事实上,TorchRec 模块有两个对应的类,用于在分布式环境中处理任何 TorchRec 模块:
-
模块分片器:该类提供了一个
shard
API,用于处理 TorchRec 模块的分片,生成一个分片后的模块。* 对于EmbeddingBagCollection
,分片器是 EmbeddingBagCollectionSharder -
分片模块:该类是 TorchRec 模块的分片版本。它与常规的 TorchRec 模块具有相同的输入/输出,但更加优化,并且可以在分布式环境中工作。* 对于
EmbeddingBagCollection
,分片版本是 ShardedEmbeddingBagCollection
每个 TorchRec 模块都有未分片和分片两种变体。
-
非分片版本适用于原型设计和实验。
-
分片版本适用于分布式环境中的分布式训练和推理。
TorchRec 模块的分片版本,例如 EmbeddingBagCollection
,将处理模型并行所需的所有操作,例如在 GPU 之间进行通信以将嵌入分发到正确的 GPU。
关于我们的 EmbeddingBagCollection
模块的回顾
ebc
fromtorchrec.distributed.embeddingbagimport EmbeddingBagCollectionSharder
fromtorchrec.distributed.plannerimport EmbeddingShardingPlanner, Topology
fromtorchrec.distributed.typesimport ShardingEnv
# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()
# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"
print(f"Process Group: {pg}")
Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f9983a76230>
规划器
在展示分片工作原理之前,我们必须了解规划器,它帮助我们确定最佳的分片配置。
给定若干嵌入表和若干计算资源,存在多种可能的分片配置。例如,给定 2 个嵌入表和 2 个 GPU,您可以:
-
在每个 GPU 上放置一个表
-
将两个表都放置在单个 GPU 上,另一个 GPU 上不放置任何表
-
在每个 GPU 上放置特定的行和列
考虑到所有这些可能性,我们通常需要一个在性能上最优的分片配置。
这就是规划器的作用所在。规划器能够根据嵌入表的数量和 GPU 的数量,确定最优的配置。实际上,手动完成这项工作极其困难,工程师需要考虑大量因素以确保分片计划的最优性。幸运的是,当使用规划器时,TorchRec 提供了一个自动规划器。
TorchRec 规划器:
-
评估硬件的内存限制
-
根据内存获取(如嵌入查找)估算计算量
-
解决数据特定的因素
-
考虑其他硬件特性(如带宽)以生成最优的分片方案
为了综合考虑所有这些变量,TorchRec 规划器可以接收嵌入表、约束条件、硬件信息和拓扑结构等各种数据,以帮助生成模型的最优分片计划,这些数据通常会在各个堆栈中提供。
要了解更多关于分片的信息,请参阅我们的分片教程。
# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size=1,
compute_device="cuda",
)
)
# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)
print(f"Sharding Plan generated: {plan}")
Sharding Plan generated: module:
param | sharding type | compute kernel | ranks
*------------ | ------------- | -------------- | -----
product_table | table_wise | fused | [0]
user_table | table_wise | fused | [0]
param | shard offsets | shard sizes | placement
*------------ | ------------- | ----------- | -------------
product_table | [0, 0] | [4096, 64] | rank:0/cuda:0
user_table | [0, 0] | [4096, 64] | rank:0/cuda:0
规划器结果
如上所示,在运行规划器时会有相当多的输出。我们可以看到大量统计数据的计算过程,以及我们的表最终被放置的位置。
运行规划器的结果是一个静态计划,该计划可以重复用于分片!这使得生产模型中的分片可以保持静态,而不需要每次重新确定一个新的分片计划。接下来,我们使用这个分片计划最终生成我们的 ShardedEmbeddingBagCollection
。
# The static plan that was generated
plan
env = ShardingEnv.from_process_group(pg)
# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
print(f"Sharded EBC Module: {sharded_ebc}")
Sharded EBC Module: ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_output_dists):
TwPooledEmbeddingDist()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
使用 LazyAwaitable
进行 GPU 训练
请记住,TorchRec 是一个为分布式嵌入高度优化的库。TorchRec 引入了一个概念 LazyAwaitable,以提高在 GPU 上训练的性能。您会在各种分片 TorchRec 模块的输出中看到 LazyAwaitable
类型。LazyAwaitable
类型的作用是尽可能延迟某些结果的计算,它通过像异步类型一样操作来实现这一点。
fromtypingimport List
fromtorchrec.distributed.typesimport LazyAwaitable
# Demonstrate a ``LazyAwaitable`` type:
classExampleAwaitable(LazyAwaitable[torch.Tensor]):
def__init__(self, size: List[int]) -> None:
super().__init__()
self._size = size
def_wait_impl(self) -> torch.Tensor:
return torch.ones(self._size)
awaitable = ExampleAwaitable([3, 2])
awaitable.wait()
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)
kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))
print(kt.keys())
print(kt.values().shape)
# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7f993cea0fa0>
<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
Sharded TorchRec 模块解析
我们现在已经根据生成的切分计划成功地对EmbeddingBagCollection
进行了分片!分片后的模块拥有来自TorchRec的通用API,这些API抽象了多个GPU之间的分布式通信和计算。实际上,这些API在训练和推理性能上进行了高度优化。以下是TorchRec提供的用于分布式训练/推理的三个常见API:
-
input_dist
: 负责将输入从 GPU 分发到 GPU。 -
lookups
: 使用 FBGEMM TBE 以优化的批处理方式执行实际的嵌入查找(后续会详细介绍)。 -
output_dist
: 负责将输出从 GPU 分发到 GPU。
输入和输出的分发是通过 NCCL Collectives 完成的,特别是 All-to-Alls,即所有 GPU 之间相互发送和接收数据。TorchRec 与 PyTorch 分布式系统交互来进行集合操作,并为最终用户提供了清晰的抽象,消除了对底层细节的担忧。
反向传播过程会执行所有这些集合操作,但顺序相反,用于梯度的分发。input_dist
、lookup
和 output_dist
都依赖于分片方案。由于我们以表级方式进行分片,这些 API 是由 TwPooledEmbeddingSharding 构建的模块。
sharded_ebc
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists
# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
[TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)]
优化嵌入查找
在为一组嵌入表执行查找时,一个简单的解决方案是遍历所有的 nn.EmbeddingBags
并对每个表进行查找。这正是标准的、未分片的 EmbeddingBagCollection
所做的。然而,尽管这种解决方案很简单,但它非常慢。
FBGEMM 是一个提供高度优化的 GPU 操作符(也称为内核)的库。其中一个操作符被称为 表批处理嵌入 (TBE),它提供了两个主要的优化:
-
表批处理,允许您通过一次内核调用查找多个嵌入。
-
优化器融合,允许模块在给定标准 PyTorch 优化器和参数的情况下自我更新。
ShardedEmbeddingBagCollection
使用 FBGEMM TBE 作为查找操作,而不是传统的 nn.EmbeddingBags
,以实现优化的嵌入查找。
sharded_ebc._lookups
[GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)]
DistributedModelParallel
我们已经探索了如何对单个 EmbeddingBagCollection
进行分片!我们能够使用 EmbeddingBagCollectionSharder
和未分片的 EmbeddingBagCollection
来生成一个 ShardedEmbeddingBagCollection
模块。这个工作流程是可行的,但通常在实现模型并行时,DistributedModelParallel(DMP)会被用作标准接口。当使用 DMP 包装模型(在我们的例子中是 ebc
)时,会发生以下情况:
-
确定如何对模型进行分片。DMP 将收集可用的分片器,并制定最佳的分片计划来分割嵌入表(例如
EmbeddingBagCollection
)。 -
实际对模型进行分片。这包括在适当的设备上为每个嵌入表分配内存。
DMP 接收了我们刚刚实验的所有内容,比如静态分片计划、分片器列表等。然而,它也有一些不错的默认设置,可以无缝地对 TorchRec 模型进行分片。在这个简单的示例中,由于我们有两个嵌入表和一个 GPU,TorchRec 会将它们都放置在单个 GPU 上。
ebc
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
out = model(kjt)
out.wait()
model
DistributedModelParallel(
(_dmp_wrapped_module): ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_input_dists):
TwSparseFeaturesDist(
(_dist): KJTAllToAll()
)
(_output_dists):
TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
分片最佳实践
目前,我们的配置仅在1个GPU(或rank)上进行分片,这很简单:只需将所有表放置在1个GPU的内存中。然而,在实际生产用例中,嵌入表通常会在数百个GPU上进行分片,并采用不同的分片方法,例如表级、行级和列级分片。确定一个合适的分片配置(以防止内存不足问题),同时不仅在内存方面保持平衡,还要在计算上保持平衡以实现最佳性能,这一点极为重要。
添加优化器
请记住,TorchRec 模块针对大规模分布式训练进行了高度优化。其中一项重要的优化涉及优化器。
TorchRec 模块提供了一个无缝的 API,用于在训练过程中融合反向传播和优化步骤,从而显著提升性能并减少内存使用,同时还支持为不同的模型参数分配不同的优化器,实现更细粒度的控制。
优化器类
TorchRec 使用 CombinedOptimizer
,它包含一组 KeyedOptimizer
。CombinedOptimizer
有效地简化了处理模型中不同子组的多个优化器的操作。KeyedOptimizer
继承自 torch.optim.Optimizer
,并通过参数字典进行初始化,暴露这些参数。EmbeddingBagCollection
中的每个 TBE
模块都会有自己的 KeyedOptimizer
,这些 KeyedOptimizer
会组合成一个 CombinedOptimizer
。
TorchRec 中的融合优化器
使用 DistributedModelParallel
时,优化器是融合的,这意味着优化器更新在反向传播过程中完成。这是 TorchRec 和 FBGEMM 中的一项优化,其中优化器的嵌入梯度不会被具体化,而是直接应用于参数。这带来了显著的内存节省,因为嵌入梯度通常与参数本身的大小相当。
然而,您可以选择将优化器设置为 dense
,这样就不会应用此优化,从而允许您检查嵌入梯度或对其进行所需的计算。在这种情况下,一个 dense
优化器将是您的规范的 PyTorch 模型训练循环与优化器。
通过 DistributedModelParallel
创建优化器后,您仍然需要管理那些与 TorchRec 嵌入模块无关的参数的优化器。要查找这些参数,请使用 in_backward_optimizer_filter(model.named_parameters())
。像对待普通的 Torch 优化器一样对这些参数应用优化器,并将其与 model.fused_optimizer
结合到一个 CombinedOptimizer
中,您可以在训练循环中使用它来执行 zero_grad
和 step
操作。
为 EmbeddingBagCollection
添加优化器
我们将通过两种方式来实现,这两种方式是等效的,但可以根据您的偏好提供不同的选择:
-
通过 sharder 中的
fused_params
传递优化器参数。 -
通过
apply_optimizer_in_backward
,将优化器参数转换为fused_params
,然后传递给EmbeddingBagCollection
或EmbeddingCollection
中的TBE
。
# Option 1: Passing optimizer kwargs through fused parameters
fromtorchrec.optim.optimizersimport in_backward_optimizer_filter
fromfbgemm_gpu.split_embedding_configsimport EmbOptimType
# We initialize the sharder with
fused_params = {
"optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
"learning_rate": 0.02,
"eps": 0.002,
}
# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))
# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")
fromtorch.distributed.optimimport _apply_optimizer_in_backward as apply_optimizer_in_backward
importcopy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it
# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}
for name, param in ebc_apply_opt.named_parameters():
print(f"{name=}")
apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)
sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))
# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))
# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())
# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")
loss.backward()
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.01
)
Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.02
)
Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
name='embedding_bags.product_table.weight'
name='embedding_bags.user_table.weight'
: EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.5
)
<class 'torchrec.optim.keyed.CombinedOptimizer'>
Non Fused Model Parameters:
dict_keys([])
First Iteration Loss: 255.66006469726562
Second Iteration Loss: 245.43795776367188
推理
现在我们已经能够训练分布式嵌入模型,那么如何利用训练好的模型并优化它以进行推理呢?推理通常对模型的性能和大小非常敏感。直接在 Python 环境中运行训练好的模型效率极低。推理环境和训练环境有两个关键区别:
-
量化:推理模型通常会被量化,通过降低模型参数的精度来减少预测延迟和模型大小。例如,训练模型中的 FP32(4 字节)可以量化为每个嵌入权重的 INT8(1 字节)。鉴于嵌入表的规模庞大,这也是必要的,因为我们希望尽可能少地使用设备进行推理,以最大限度地减少延迟。
-
C++ 环境:推理延迟非常重要,因此为了确保足够的性能,模型通常在 C++ 环境中运行,尤其是在没有 Python 运行时的情况下,比如在设备上。
TorchRec 提供了一些原语,用于将 TorchRec 模型转换为推理就绪状态,具体包括:
-
用于量化模型的API,通过FBGEMM TBE自动引入优化
-
为分布式推理分片嵌入
-
将模型编译为TorchScript(兼容C++)
在本节中,我们将详细介绍以下工作流程:
-
量化模型
-
对量化后的模型进行分片
-
将分片后的量化模型编译为 TorchScript
ebc
classInferenceModule(torch.nn.Module):
def__init__(self, ebc: torchrec.EmbeddingBagCollection):
super().__init__()
self.ebc_ = ebc
defforward(self, kjt: KeyedJaggedTensor):
return self.ebc_(kjt)
module = InferenceModule(ebc)
for name, param in module.named_parameters():
# Here, the parameters should still be FP32, as we are using a standard EBC
# FP32 is default, regularly used for training
print(name, param.shape, param.dtype)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32
ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32
量化
如上所示,普通的 EBC 包含 FP32 精度的嵌入表权重(每个权重占 32 位)。在这里,我们将使用 TorchRec 推理库将模型的嵌入权重量化为 INT8。
fromtorchimport quantization as quant
fromtorchrec.modules.embedding_configsimport QuantConfig
fromtorchrec.quant.embedding_modulesimport (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
quant_dtype = torch.int8
qconfig = QuantConfig(
# dtype of the result of the embedding lookup, post activation
# torch.float generally for compatibility with rest of the model
# as rest of the model here usually isn't quantized
activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
# quantized type for embedding weights, aka parameters to actually quantize
weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
# Map of module type to qconfig
torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
# Map of module type to quantized module type
torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}
module = InferenceModule(ebc)
# Quantize the module
qebc = quant.quantize_dynamic(
module,
qconfig_spec=qconfig_spec,
mapping=mapping,
inplace=False,
)
print(f"Quantized EBC: {qebc}")
kjt = kjt.to("cpu")
qebc(kjt)
# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
# The shapes of the tables should be the same but the dtype should be int8 now
# post quantization
print(name, buffer.shape, buffer.dtype)
Quantized EBC: InferenceModule(
(ebc_): QuantizedEmbeddingBagCollection(
(_kjt_to_jt_dict): ComputeKJTToJTDict()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8
ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8
分片
在这里,我们对 TorchRec 量化模型进行分片。这是为了确保我们通过 FBGEMM TBE 使用高性能模块。为了与训练保持一致(1 个 TBE),我们在这里使用了一个设备。
fromtorchrecimport distributed as trec_dist
fromtorchrec.distributed.shardimport _shard_modules
sharded_qebc = _shard_modules(
module=qebc,
device=torch.device("cpu"),
env=trec_dist.ShardingEnv.from_local(
1,
0,
),
)
print(f"Sharded Quantized EBC: {sharded_qebc}")
sharded_qebc(kjt)
Sharded Quantized EBC: InferenceModule(
(ebc_): ShardedQuantEmbeddingBagCollection(
(lookups):
InferGroupedPooledEmbeddingsLookup()
(_output_dists): ModuleList()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
(_input_dist_module): ShardedQuantEbcInputDist()
)
)
<torchrec.sparse.jagged_tensor.KeyedTensor object at 0x7f993cea19f0>
编译
现在我们有了经过优化的急切模式 TorchRec 推理模型。下一步是确保该模型可以在 C++ 中加载,因为目前它只能在 Python 运行时中运行。
在 Meta 推荐的编译方法分为两步:torch.fx tracing(生成模型的中间表示)以及将结果转换为 TorchScript,其中 TorchScript 是 C++ 兼容的。
fromtorchrec.fximport Tracer
tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])
graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)
print("Graph Module Created!")
print(gm.code)
scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")
print(scripted_gm.code)
Graph Module Created!
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embeddingbag_flatten_feature_lengths")
torch.fx._symbolic_trace.wrap("torchrec_fx_utils__fx_marker")
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embedding_kernel__unwrap_kjt")
torch.fx._symbolic_trace.wrap("fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device")
torch.fx._symbolic_trace.wrap("torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference")
def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):
flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt); kjt = None
_fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths); _fx_marker = None
split = flatten_feature_lengths.split([2])
getitem = split[0]; split = None
to = getitem.to(device(type='cuda', index=0), non_blocking = True); getitem = None
_fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths); flatten_feature_lengths = _fx_marker_1 = None
_unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to); to = None
getitem_1 = _unwrap_kjt[0]
getitem_2 = _unwrap_kjt[1]
getitem_3 = _unwrap_kjt[2]; _unwrap_kjt = getitem_3 = None
inputs_to_device = fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device(getitem_1, getitem_2, None, device(type='cuda', index=0)); getitem_1 = getitem_2 = None
getitem_4 = inputs_to_device[0]
getitem_5 = inputs_to_device[1]
getitem_6 = inputs_to_device[2]; inputs_to_device = None
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6); _tensor_constant0 = _tensor_constant1 = bounds_check_indices = None
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_4, offsets = getitem_5, pooling_mode = 0, indice_weights = getitem_6, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1); _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_4 = getitem_5 = getitem_6 = _tensor_constant8 = _tensor_constant9 = None
embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32); int_nbit_split_embedding_codegen_lookup_function = None
to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu')); embeddings_cat_empty_rank_handle_inference = None
keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1); to_1 = None
return keyed_tensor
/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning:
The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
Scripted Graph Module Created!
def forward(self,
kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:
_0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths
_1 = __torch__.torchrec.fx.utils._fx_marker
_2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt
_3 = __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_inference.inputs_to_device
_4 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference
flatten_feature_lengths = _0(kjt, )
_fx_marker = _1("KJT_ONE_TO_ALL_FORWARD_BEGIN", flatten_feature_lengths, )
split = (flatten_feature_lengths).split([2], )
getitem = split[0]
to = (getitem).to(torch.device("cuda", 0), True, None, )
_fx_marker_1 = _1("KJT_ONE_TO_ALL_FORWARD_END", flatten_feature_lengths, )
_unwrap_kjt = _2(to, )
getitem_1 = (_unwrap_kjt)[0]
getitem_2 = (_unwrap_kjt)[1]
inputs_to_device = _3(getitem_1, getitem_2, None, torch.device("cuda", 0), )
getitem_4 = (inputs_to_device)[0]
getitem_5 = (inputs_to_device)[1]
getitem_6 = (inputs_to_device)[2]
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6)
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_4, getitem_5, 0, getitem_6, 0, _tensor_constant8, _tensor_constant9, 16)
_5 = [int_nbit_split_embedding_codegen_lookup_function]
embeddings_cat_empty_rank_handle_inference = _4(_5, 1, "cuda:0", 6, )
to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device("cpu"))
_6 = ["product", "user"]
_7 = [64, 64]
keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)
_8 = (keyed_tensor).__init__(_6, _7, to_1, 1, None, None, )
return keyed_tensor
结论
在本教程中,您已经从训练一个分布式推荐系统模型开始,一直到使其准备好进行推理。TorchRec 仓库中有一个完整的示例,展示了如何将 TorchRec 的 TorchScript 模型加载到 C++ 中进行推理。
如需更多信息,请参阅我们的 dlrm 示例,其中包括使用《用于个性化和推荐系统的深度学习推荐模型》中描述的方法,在 Criteo 1TB 数据集上进行多节点训练。