__repr__ 和 state_dict() 能看到所有模块信息,但 .modules() / .named_modules() 只能看到 self.embeddings(一个 nn.ModuleDict),看不到实际执行计算的 lookup / input_dist / output_dist 模块。
根本原因在于 ShardedEmbeddingModule 基类将三大核心子模块存储为普通 Python List,而不是 nn.ModuleList:
# torchrec/distributed/embedding_types.py:392-394
class ShardedEmbeddingModule(...):
def __init__(self, ...):
self._input_dists: List[nn.Module] = [] # ← 普通 list, 不是 nn.ModuleList!
self._lookups: List[nn.Module] = [] # ← 普通 list, 不是 nn.ModuleList!
self._output_dists: List[nn.Module] = [] # ← 普通 list, 不是 nn.ModuleList!
PyTorch 的 nn.Module.__setattr__ 只会自动注册以下类型到 _modules 字典:
nn.Module 实例nn.ModuleList / nn.ModuleDictadd_module() / register_module() 显式注册普通 Python list 中的 nn.Module 不会被 PyTorch 追踪,因此 .modules()、.named_modules()、.parameters() 等方法都无法遍历到它们。
而 __repr__ 之所以能显示 lookups,是因为 ShardedEmbeddingModule 重写了 extra_repr(),手动遍历这些 list:
# torchrec/distributed/embedding_types.py:424-444
def extra_repr(self) -> str:
rep = []
rep.extend(loop("lookups", self._lookups)) # 手动遍历 list
rep.extend(loop("_input_dists", self._input_dists))
rep.extend(loop("_output_dists", self._output_dists))
return "\n ".join(rep)
下图展示了 ShardedEmbeddingCollection 的完整内部结构。绿色为 PyTorch .modules() 可见部分,红色为不可见部分:
| 原因 | 说明 |
|---|---|
| 避免重复参数追踪 | Lookup 内部的 TBE weight 已经通过 self.embeddings[table].weight 注册为参数。如果 lookups 也被注册为子模块,.parameters() 会返回重复的参数引用,导致优化器问题。 |
| State Dict 路径控制 | 需要 state_dict 的 key 为 embeddings.table_name.weight(与原始 EmbeddingCollection 一致),而不是 _lookups.0._emb_modules.0.table_name.weight。 |
| 计算与状态分离 | self.embeddings 负责对外暴露兼容的 nn.Module API(state_dict / parameters / named_modules),self._lookups 负责实际的 forward 计算。两者解耦。 |
| 训练模式手动管理 | ShardedEmbeddingModule.train() 手动对 _lookups 调用 .train(mode),不依赖 PyTorch 自动递归。 |
以下是 TorchRec 在各个层级对 PyTorch nn.Module 标准方法的修改汇总:
| PyTorch 方法 | 处理方式 | 文件位置 |
|---|---|---|
state_dict() |
pre hook _pre_state_dict_hook — 先 flush lookupspost hook post_state_dict_hook — 将 plain tensor 替换为 ShardedTensor/DTensor
|
embedding.py:712, 1047 |
load_state_dict() |
pre hook _pre_load_state_dict_hook — 将 ShardedTensor/DTensor 拆为 local tensorpost hook _post_load_state_dict_hook — 清理 virtual table 的 missing_keys
|
embedding.py:724, 1160 |
extra_repr() |
override 手动遍历 _lookups, _input_dists, _output_dists 生成 repr |
embedding_types.py:424 |
train() |
override 除调用 super().train() 外,手动对每个 lookup 调用 .train(mode) |
embedding_types.py:446 |
modules() / named_modules() |
未重写 — 只能看到 self.embeddings (nn.ModuleDict) |
— |
parameters() / named_parameters() |
未重写 — 通过 self.embeddings[table].weight 注册的 TableBatchedEmbeddingSlice 可见 |
— |
| PyTorch 方法 | 处理方式 | 文件位置 |
|---|---|---|
state_dict() |
override 遍历 _emb_modules 收集每个 kernel 的 state_dict |
embedding_lookup.py:388 |
load_state_dict() |
override 通过 _load_state_dict() 辅助函数分发到各 kernel |
embedding_lookup.py:407 |
named_parameters() |
override 手动遍历 _emb_modules |
embedding_lookup.py:415 |
named_buffers() |
override 手动遍历 _emb_modules |
embedding_lookup.py:425 |
自定义: named_parameters_by_table() |
新增 按 table 名返回 TableBatchedEmbeddingSlice |
embedding_lookup.py:435 |
| PyTorch 方法 | 处理方式 | 文件位置 |
|---|---|---|
state_dict() |
override 通过 get_state_dict() 辅助函数,从 TBE 的 split_embedding_weights() 构造 |
batched_embedding_kernel.py:1836 |
named_parameters() |
override 调用 named_split_embedding_weights(),包裹为 nn.Parameter 并挂 _in_backward_optimizers |
batched_embedding_kernel.py:2606 |
named_buffers() |
override 返回空 — fused params 以 buffer 形式存在于 TBE 内部,不对外暴露 | batched_embedding_kernel.py:2596 |
自定义: named_parameters_by_table() |
新增 返回 (table_name, TableBatchedEmbeddingSlice) 迭代器 |
batched_embedding_kernel.py:1890 |
model.state_dict()self.embeddings (nn.ModuleDict).weight (TableBatchedEmbeddingSlice → 实际是 TBE weight 的 view)
ShardedTensor 或 DTensor,get_named_split_embedding_weights_snapshot() 获取快照
最终 state_dict 的 key 格式:embeddings.{table_name}.weight
model.load_state_dict()local_shards()[0].tensor.to_local()weight_id, bucket, metadata keys
self.embeddings[table].weightTableBatchedEmbeddingSlice(TBE weight 的 view),数据直接写入 TBE
# torchrec/distributed/embedding.py:1170-1175
self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
self._pre_load_state_dict_hook, with_module=True
)
self.register_load_state_dict_post_hook(_post_load_state_dict_hook)
_lookups 列表逐层访问。
sharded_ec: ShardedEmbeddingCollection = ...
for lookup in sharded_ec._lookups:
# lookup 是 GroupedEmbeddingsLookup
for emb_module in lookup._emb_modules:
# emb_module 是 BatchedFusedEmbedding / BatchedDenseEmbedding / ...
tbe = emb_module._emb_module
# tbe 就是 SplitTableBatchedEmbeddingBagsCodegen (fbgemm TBE)
print(type(tbe))
# <class 'fbgemm_gpu...SplitTableBatchedEmbeddingBagsCodegen'>
def get_all_tbes(sharded_module):
"""从 ShardedEmbeddingCollection 或 ShardedEmbeddingBagCollection 提取所有 TBE"""
tbes = []
for lookup in sharded_module._lookups:
# 处理 DDP 包裹
while isinstance(lookup, DistributedDataParallel):
lookup = lookup.module
for emb_module in lookup._emb_modules:
if hasattr(emb_module, '_emb_module'):
tbes.append(emb_module._emb_module)
return tbes
# 按 table 获取权重 (不需要直接访问 TBE)
for lookup in sharded_ec._lookups:
for table_name, tbe_slice in lookup.named_parameters_by_table():
print(f"{table_name}: shape={tbe_slice.shape}, dtype={tbe_slice.dtype}")
# tbe_slice 是 TableBatchedEmbeddingSlice(nn.Parameter)
# 它是底层 TBE weight tensor 的一个 view
# 如果你的模型被 DMP 包裹,sharded module 在 DMP 内部
dmp_model: DistributedModelParallel = ...
# 找到 sharded embedding collection
for name, module in dmp_model._dmp_wrapped_module.named_modules():
if isinstance(module, ShardedEmbeddingCollection):
for lookup in module._lookups:
for emb in lookup._emb_modules:
tbe = emb._emb_module # fbgemm TBE
_lookups, _emb_modules, _emb_module 都是私有属性,不属于公开 API,未来版本可能变动。DistributedDataParallel 包裹(DATA_PARALLEL sharding type),需要先 unwrap lookup.module。_emb_module 属性。| 类名 | 文件 | 职责 | 存储方式 |
|---|---|---|---|
ShardedEmbeddingCollection |
distributed/embedding.py | Sharded EC 的顶层容器,管理 state_dict 兼容性 | — |
ShardedEmbeddingModule |
distributed/embedding_types.py | 基类,定义 _lookups/_input_dists/_output_dists | plain List[nn.Module] |
GroupedEmbeddingsLookup |
distributed/embedding_lookup.py | 按 sharding type 分组的 lookup,管理多个 embedding kernel | nn.ModuleList (_emb_modules) |
BatchedFusedEmbedding |
distributed/batched_embedding_kernel.py | FUSED compute kernel 的实现,包裹 FBGEMM TBE | _emb_module (nn.Module 自动注册) |
SplitTableBatchedEmbedding |
fbgemm_gpu (外部) | 真正的 TBE — 执行 GPU embedding lookup 的 CUDA kernel | — |
TableBatchedEmbeddingSlice |
distributed/composable/ table_batched_embedding_slice.py |
TBE weight 的 view wrapper,继承 nn.Parameter | registered in embeddings[table].weight |
EmbeddingFusedOptimizer |
distributed/batched_embedding_kernel.py | 包裹 TBE 的内置优化器(SGD/Adam/etc in CUDA) | _optim 属性 |
TorchRec 中 embedding weight 同时存在于两条路径,这是理解整个系统的关键:
# torchrec/distributed/embedding.py:919-929
for (table_name, tbe_slice) in lookup.named_parameters_by_table():
# tbe_slice 是 TableBatchedEmbeddingSlice
# 它底层是 TBE weight tensor 的一个 view
self.embeddings[table_name].register_parameter("weight", tbe_slice)
# ↑ nn.ModuleDict 中的空 nn.Module
# 现在有了 .weight 参数, 指向 TBE 内存
这样做的效果:
state_dict() 遍历 self.embeddings 时能找到 .weight 参数parameters() 能正确返回所有 embedding weightspost_state_dict_hook 在 checkpoint 时替换为 ShardedTensor/DTensor 格式TorchRec 的 ShardedEmbeddingCollection 通过精心设计的双层结构实现了以下目标:
state_dict(), parameters() 的行为与标准 EmbeddingCollection 一致TableBatchedEmbeddingSlice 是 view,不占额外内存代价是 .modules() / .named_modules() 无法看到计算子模块,需要通过 ._lookups 私有属性手动访问。
torchrec/distributed/