TorchRec ShardedEmbeddingCollection
内部机制深度解析

PyTorch nn.Module FBGEMM TBE TorchRec Sharding
目录
  1. 问题现象:为什么 modules() 看不到 lookups?
  2. 完整模块层级结构
  3. 为什么用 Plain List 而不是 nn.ModuleList?
  4. TorchRec 重写/hook 了哪些 PyTorch 函数
  5. State Dict 完整工作流
  6. 如何访问底层 FBGEMM TBE Module
  7. 关键类速查表
  8. Embedding Weight 的双重身份

1. 问题现象:为什么 modules() 看不到 lookups?

⚠ 核心问题
__repr__state_dict() 能看到所有模块信息,但 .modules() / .named_modules() 只能看到 self.embeddings(一个 nn.ModuleDict),看不到实际执行计算的 lookup / input_dist / output_dist 模块。

根本原因在于 ShardedEmbeddingModule 基类将三大核心子模块存储为普通 Python List,而不是 nn.ModuleList

# torchrec/distributed/embedding_types.py:392-394
class ShardedEmbeddingModule(...):
    def __init__(self, ...):
        self._input_dists: List[nn.Module] = []   # ← 普通 list, 不是 nn.ModuleList!
        self._lookups: List[nn.Module] = []      # ← 普通 list, 不是 nn.ModuleList!
        self._output_dists: List[nn.Module] = [] # ← 普通 list, 不是 nn.ModuleList!

PyTorch 的 nn.Module.__setattr__ 只会自动注册以下类型到 _modules 字典:

普通 Python list 中的 nn.Module 不会被 PyTorch 追踪,因此 .modules().named_modules().parameters() 等方法都无法遍历到它们。

__repr__ 之所以能显示 lookups,是因为 ShardedEmbeddingModule 重写了 extra_repr(),手动遍历这些 list:

# torchrec/distributed/embedding_types.py:424-444
def extra_repr(self) -> str:
    rep = []
    rep.extend(loop("lookups", self._lookups))       # 手动遍历 list
    rep.extend(loop("_input_dists", self._input_dists))
    rep.extend(loop("_output_dists", self._output_dists))
    return "\n ".join(rep)

2. 完整模块层级结构

下图展示了 ShardedEmbeddingCollection 的完整内部结构。绿色为 PyTorch .modules() 可见部分,红色为不可见部分:

ShardedEmbeddingCollection │ ├── self.embeddings nn.ModuleDict ✓ modules()可见 │ ├── "table_0" → nn.Module │ │ └── .weight → TableBatchedEmbeddingSlice(nn.Parameter) │ ├── "table_1" → nn.Module │ │ └── .weight → TableBatchedEmbeddingSlice(nn.Parameter) │ └── ... │ ├── self._lookups plain list ✗ modules()不可见 │ └── GroupedEmbeddingsLookup(nn.Module) │ └── self._emb_modules (nn.ModuleList) │ └── BatchedFusedEmbedding(nn.Module) │ ├── self._emb_module → SplitTableBatchedEmbeddingBagsCodegen ← 真正的 FBGEMM TBE! │ └── self._optim → EmbeddingFusedOptimizer │ ├── self._input_dists plain list ✗ 不可见 │ └── input distribution modules (AllToAll, etc.) │ └── self._output_dists plain list ✗ 不可见 └── output distribution modules (AllToAll, etc.)

继承链

ShardedEmbeddingCollection ↑ extends ShardedEmbeddingModule[CompIn, DistOut, Out, ShrdCtx] # embedding_types.py:373 ↑ extends ShardedModule[CompIn, DistOut, Out, ShrdCtx] # types.py:1184 ↑ extends abc.ABC + nn.Module + ModuleNoCopyMixin 同时还实现: FusedOptimizerModule # optim/fused.py — 提供 fused_optimizer 属性 ModuleShardingMixIn # 提供 module_sharding_plan 属性

3. 为什么用 Plain List 而不是 nn.ModuleList?

设计意图
这是有意为之的设计选择,原因如下:
原因说明
避免重复参数追踪 Lookup 内部的 TBE weight 已经通过 self.embeddings[table].weight 注册为参数。如果 lookups 也被注册为子模块,.parameters() 会返回重复的参数引用,导致优化器问题。
State Dict 路径控制 需要 state_dict 的 key 为 embeddings.table_name.weight(与原始 EmbeddingCollection 一致),而不是 _lookups.0._emb_modules.0.table_name.weight
计算与状态分离 self.embeddings 负责对外暴露兼容的 nn.Module API(state_dict / parameters / named_modules),self._lookups 负责实际的 forward 计算。两者解耦。
训练模式手动管理 ShardedEmbeddingModule.train() 手动对 _lookups 调用 .train(mode),不依赖 PyTorch 自动递归。

4. TorchRec 重写 / Hook 了哪些 PyTorch 函数

以下是 TorchRec 在各个层级对 PyTorch nn.Module 标准方法的修改汇总:

4.1 ShardedEmbeddingCollection 层

PyTorch 方法处理方式文件位置
state_dict() pre hook _pre_state_dict_hook — 先 flush lookups
post hook post_state_dict_hook — 将 plain tensor 替换为 ShardedTensor/DTensor
embedding.py:712, 1047
load_state_dict() pre hook _pre_load_state_dict_hook — 将 ShardedTensor/DTensor 拆为 local tensor
post hook _post_load_state_dict_hook — 清理 virtual table 的 missing_keys
embedding.py:724, 1160
extra_repr() override 手动遍历 _lookups, _input_dists, _output_dists 生成 repr embedding_types.py:424
train() override 除调用 super().train() 外,手动对每个 lookup 调用 .train(mode) embedding_types.py:446
modules() / named_modules() 未重写 — 只能看到 self.embeddings (nn.ModuleDict)
parameters() / named_parameters() 未重写 — 通过 self.embeddings[table].weight 注册的 TableBatchedEmbeddingSlice 可见

4.2 GroupedEmbeddingsLookup 层

PyTorch 方法处理方式文件位置
state_dict() override 遍历 _emb_modules 收集每个 kernel 的 state_dict embedding_lookup.py:388
load_state_dict() override 通过 _load_state_dict() 辅助函数分发到各 kernel embedding_lookup.py:407
named_parameters() override 手动遍历 _emb_modules embedding_lookup.py:415
named_buffers() override 手动遍历 _emb_modules embedding_lookup.py:425
自定义: named_parameters_by_table() 新增 按 table 名返回 TableBatchedEmbeddingSlice embedding_lookup.py:435

4.3 BatchedFusedEmbedding 层 (TBE Wrapper)

PyTorch 方法处理方式文件位置
state_dict() override 通过 get_state_dict() 辅助函数,从 TBE 的 split_embedding_weights() 构造 batched_embedding_kernel.py:1836
named_parameters() override 调用 named_split_embedding_weights(),包裹为 nn.Parameter 并挂 _in_backward_optimizers batched_embedding_kernel.py:2606
named_buffers() override 返回空 — fused params 以 buffer 形式存在于 TBE 内部,不对外暴露 batched_embedding_kernel.py:2596
自定义: named_parameters_by_table() 新增 返回 (table_name, TableBatchedEmbeddingSlice) 迭代器 batched_embedding_kernel.py:1890

5. State Dict 完整工作流

5.1 保存路径: model.state_dict()

1
pre_state_dict_hook — flush 所有 lookups
确保 UVM cache / KV store 中的数据被写回
2
PyTorch 标准递归 — 遍历 self.embeddings (nn.ModuleDict)
收集每个 table 的 .weight (TableBatchedEmbeddingSlice → 实际是 TBE weight 的 view)
3
post_state_dict_hook — 替换为分布式张量
将 destination 中的 plain tensor 替换为预构建的 ShardedTensorDTensor
对 KV store 类型额外调用 get_named_split_embedding_weights_snapshot() 获取快照

最终 state_dict 的 key 格式:embeddings.{table_name}.weight

5.2 加载路径: model.load_state_dict()

1
pre_load_state_dict_hook — 拆解分布式张量
ShardedTensor → 提取 local_shards()[0].tensor
DTensor → .to_local()
Virtual Table → 删除 weight_id, bucket, metadata keys
2
PyTorch 标准加载 — 将 tensor 拷贝到 self.embeddings[table].weight
因为 weight 是 TableBatchedEmbeddingSlice(TBE weight 的 view),数据直接写入 TBE
3
post_load_state_dict_hook — 清理 missing keys
Virtual table 的 runtime 生成 key 不应报错

5.3 Hook 注册代码

# torchrec/distributed/embedding.py:1170-1175
self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
    self._pre_load_state_dict_hook, with_module=True
)
self.register_load_state_dict_post_hook(_post_load_state_dict_hook)

6. 如何访问底层 FBGEMM TBE Module

✓ 推荐方法
通过 _lookups 列表逐层访问。

方法一:直接遍历 _lookups(最常用)

sharded_ec: ShardedEmbeddingCollection = ...

for lookup in sharded_ec._lookups:
    # lookup 是 GroupedEmbeddingsLookup
    for emb_module in lookup._emb_modules:
        # emb_module 是 BatchedFusedEmbedding / BatchedDenseEmbedding / ...
        tbe = emb_module._emb_module
        # tbe 就是 SplitTableBatchedEmbeddingBagsCodegen (fbgemm TBE)
        print(type(tbe))
        # <class 'fbgemm_gpu...SplitTableBatchedEmbeddingBagsCodegen'>

方法二:获取所有 TBE 的通用辅助函数

def get_all_tbes(sharded_module):
    """从 ShardedEmbeddingCollection 或 ShardedEmbeddingBagCollection 提取所有 TBE"""
    tbes = []
    for lookup in sharded_module._lookups:
        # 处理 DDP 包裹
        while isinstance(lookup, DistributedDataParallel):
            lookup = lookup.module
        for emb_module in lookup._emb_modules:
            if hasattr(emb_module, '_emb_module'):
                tbes.append(emb_module._emb_module)
    return tbes

方法三:通过 named_parameters_by_table() 访问权重切片

# 按 table 获取权重 (不需要直接访问 TBE)
for lookup in sharded_ec._lookups:
    for table_name, tbe_slice in lookup.named_parameters_by_table():
        print(f"{table_name}: shape={tbe_slice.shape}, dtype={tbe_slice.dtype}")
        # tbe_slice 是 TableBatchedEmbeddingSlice(nn.Parameter)
        # 它是底层 TBE weight tensor 的一个 view

方法四:通过 DistributedModelParallel (DMP) 访问

# 如果你的模型被 DMP 包裹,sharded module 在 DMP 内部
dmp_model: DistributedModelParallel = ...

# 找到 sharded embedding collection
for name, module in dmp_model._dmp_wrapped_module.named_modules():
    if isinstance(module, ShardedEmbeddingCollection):
        for lookup in module._lookups:
            for emb in lookup._emb_modules:
                tbe = emb._emb_module  # fbgemm TBE
⚠ 注意事项

7. 关键类速查表

类名文件职责存储方式
ShardedEmbeddingCollection distributed/embedding.py Sharded EC 的顶层容器,管理 state_dict 兼容性
ShardedEmbeddingModule distributed/embedding_types.py 基类,定义 _lookups/_input_dists/_output_dists plain List[nn.Module]
GroupedEmbeddingsLookup distributed/embedding_lookup.py 按 sharding type 分组的 lookup,管理多个 embedding kernel nn.ModuleList (_emb_modules)
BatchedFusedEmbedding distributed/batched_embedding_kernel.py FUSED compute kernel 的实现,包裹 FBGEMM TBE _emb_module (nn.Module 自动注册)
SplitTableBatchedEmbedding
BagsCodegen
fbgemm_gpu (外部) 真正的 TBE — 执行 GPU embedding lookup 的 CUDA kernel
TableBatchedEmbeddingSlice distributed/composable/
table_batched_embedding_slice.py
TBE weight 的 view wrapper,继承 nn.Parameter registered in embeddings[table].weight
EmbeddingFusedOptimizer distributed/batched_embedding_kernel.py 包裹 TBE 的内置优化器(SGD/Adam/etc in CUDA) _optim 属性

8. Embedding Weight 的双重身份

TorchRec 中 embedding weight 同时存在于两条路径,这是理解整个系统的关键:

路径 A: State Dict / Parameter 路径 (对外) ShardedEmbeddingCollection └── .embeddings["table_0"].weight └── TableBatchedEmbeddingSlice (nn.Parameter, 是 TBE weight 的 view) │ │ post_state_dict_hook 将其替换为 ↓ │ └── ShardedTensor / DTensor (分布式张量, checkpoint 格式) 路径 B: 计算路径 (对内, forward pass) ShardedEmbeddingCollection._lookups[i] └── GroupedEmbeddingsLookup._emb_modules[j] └── BatchedFusedEmbedding._emb_module └── SplitTableBatchedEmbeddingBagsCodegen └── .split_embedding_weights() → 实际的 GPU 上的 weight tensor 两条路径指向同一块内存! TableBatchedEmbeddingSlice 是通过 TBE 的 split_embedding_weights() 获取的 view, 修改任一路径的值,另一条路径也会同步变化。

_initialize_torch_state() 如何建立连接

# torchrec/distributed/embedding.py:919-929
for (table_name, tbe_slice) in lookup.named_parameters_by_table():
    # tbe_slice 是 TableBatchedEmbeddingSlice
    # 它底层是 TBE weight tensor 的一个 view
    self.embeddings[table_name].register_parameter("weight", tbe_slice)
    #          ↑ nn.ModuleDict 中的空 nn.Module
    #            现在有了 .weight 参数, 指向 TBE 内存

这样做的效果:

总结

TorchRec 的 ShardedEmbeddingCollection 通过精心设计的双层结构实现了以下目标:

  1. API 兼容性state_dict(), parameters() 的行为与标准 EmbeddingCollection 一致
  2. 高性能计算 — forward pass 直接使用 FBGEMM TBE 的 CUDA kernel
  3. 分布式 checkpoint — 通过 hooks 自动在 plain tensor ↔ ShardedTensor/DTensor 之间转换
  4. 零内存开销TableBatchedEmbeddingSlice 是 view,不占额外内存

代价是 .modules() / .named_modules() 无法看到计算子模块,需要通过 ._lookups 私有属性手动访问。

Generated on 2026-04-10  |  Based on TorchRec source at torchrec/distributed/