NVIDIA recsys-examples · _rewrite_model · BaseForward
TorchRec's default ShardedModule.forward() hides two NCCL collectives: the splits AllToAll followed by the KJT tensors AllToAll (wait #1 and wait #2 in the code below).
Native runs both serially; the main thread blocks on input_dist's NCCL before any downstream GPU work can start. Milliseconds of AllToAll latency land directly on the critical path.
Core idea: extract input_dist out of batch_i's forward and run it during batch_(i-1)'s fwd/bwd on an independent CUDA stream. AllToAll is now fully hidden behind GPU compute.
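A minimal sketch of the overlap pattern, not the recsys-examples/TorchRec implementation: `start_sparse_data_dist` and `data_dist_stream` are illustrative names, and cross-stream synchronization (CUDA events, `record_stream`) is elided.

```python
import torch

data_dist_stream = torch.cuda.Stream()  # side stream dedicated to input_dist

def train_loop(model, optimizer, batches, start_sparse_data_dist):
    batch_iter = iter(batches)
    next_batch = next(batch_iter)
    start_sparse_data_dist(next_batch, stream=data_dist_stream)  # warm up batch_0's dist
    for batch in batch_iter:
        cur_batch, next_batch = next_batch, batch
        # Launch batch_(i+1)'s input_dist (AllToAll) on the side stream ...
        start_sparse_data_dist(next_batch, stream=data_dist_stream)
        # ... while batch_i's fwd/bwd runs on the default stream and hides it.
        loss = model(cur_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```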
self.shared_ebc(batch.feats_sparse) still evaluates the argument expression, but the result is thrown away; the real input comes from the context.
model.eval() without restoring the original forwards is unsafe: the patched forward always expects a populated context.
Each class box's three metadata lines (pop source, whether the wait stays inside __call__, and which Pipeline uses it) correspond to the three wait-timing strategies in §4.
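What _rewrite_model installs can be sketched as a plain forward swap. Everything below is illustrative: the helper name `rewrite_sketch`, the dicts, and the constructor argument order are assumptions, with §2's PipelinedForward passed in as `pipelined_forward_cls`.

```python
def rewrite_sketch(sharded_modules, arg_infos, pipeline_context, data_dist_stream,
                   pipelined_forward_cls):
    """pipelined_forward_cls stands in for §2's PipelinedForward; the constructor
    signature assumed here is (name, args, module, context, stream)."""
    original_forwards = {}
    for name, module in sharded_modules.items():
        original_forwards[name] = module.forward      # keep so the patch can be undone
        module.forward = pipelined_forward_cls(       # swap in the context-driven forward
            name, arg_infos[name], module, pipeline_context, data_dist_stream
        )
    # Without restoring original_forwards, model.eval() is unsafe: the patched
    # forward unconditionally pops from a context that only the pipeline populates.
    return original_forwards
```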
ArgInfo is Step 4's output — a recorded access path from batch to the sharded-module arg. Stripped example:
```python
# User wrote:  self.shared_ebc(batch.sparse_features)
# FX trace recovers this access path into ArgInfo:
ArgInfo(
    input_attrs      = ["", "sparse_features"],  # getattr chain from batch
    is_getitems      = [False, False],
    postproc_modules = [None, None],
    constants        = [None, None],
    name             = None,                     # positional arg
)
# Later, _start_data_dist replays this on batch_i+1:
#   batch_i+1 -> getattr "sparse_features" -> ShardedModule.input_dist(...)
```
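Replaying that path on the next batch can be sketched as follows; this is a simplified stand-in for the ArgInfo handling inside _start_data_dist (postproc_modules and constants are ignored):

```python
def replay_arg_info(arg_info, batch):
    """Walk the recorded access path from the batch down to the sparse-module input."""
    value = batch
    for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
        if attr == "":                      # empty attr means "the object itself"
            continue
        value = value[attr] if is_getitem else getattr(value, attr)
    return value                            # e.g. batch.sparse_features (a KJT)

# Inside _start_data_dist (schematically), for batch_(i+1):
#   kjt = replay_arg_info(arg_info, next_batch)
#   context.input_dist_splits_requests[name] = module.input_dist(module_ctx, kjt)
#   context.module_contexts[name] = module_ctx
```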
| Context field | Written by | Read by |
|---|---|---|
| `input_dist_splits_requests` | `_start_data_dist` | `_fuse_input_dist_splits` |
| `fused_splits_awaitables` | `_fuse_input_dist_splits` | `wait_sparse_data_dist` |
| `input_dist_tensors_requests` | `wait_sparse_data_dist` | `PipelinedForward.__call__` |
| `module_contexts` | `_start_data_dist` | `PipelinedForward.__call__` |
| `module_input_post_prefetch` | `_prefetch` (Prefetch variant) | `PrefetchPipelinedForward.__call__` |
| `module_contexts_post_prefetch` | `_prefetch` (Prefetch variant) | `PrefetchPipelinedForward.__call__` |
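Taken together, the context is just a set of per-module mailboxes. A condensed sketch of the fields above; the class name and types are approximations, not the exact TorchRec definition:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

@dataclass
class PipelineContextSketch:
    # per-module splits-AllToAll handles, filled by _start_data_dist
    input_dist_splits_requests: Dict[str, Any] = field(default_factory=dict)
    # one fused awaitable per process group, built by _fuse_input_dist_splits
    fused_splits_awaitables: List[Tuple[List[str], Any]] = field(default_factory=list)
    # per-module tensors-AllToAll handles, filled by wait_sparse_data_dist,
    # popped by PipelinedForward.__call__
    input_dist_tensors_requests: Dict[str, Any] = field(default_factory=dict)
    # per-module sharding contexts captured when the dist is started
    module_contexts: Dict[str, Any] = field(default_factory=dict)
    # Prefetch-variant counterparts, filled by _prefetch,
    # popped by PrefetchPipelinedForward.__call__
    module_input_post_prefetch: Dict[str, Any] = field(default_factory=dict)
    module_contexts_post_prefetch: Dict[str, Any] = field(default_factory=dict)
```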
The FX trace walks through every non-leaf ancestor of a ShardedModule. Any dynamic Python along the way — {**proxy}, len(proxy), value-dependent if/for — breaks the trace. Either rewrite that code to be FX-friendly, or wrap it with @torch.fx.wrap to hide it behind an opaque call_function node.
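A minimal sketch of the escape hatch (not code from recsys-examples; `dynamic_select` is an invented helper):

```python
import torch
import torch.fx

@torch.fx.wrap  # traced as one opaque call_function node; its body is never traced
def dynamic_select(features: dict, key_order: list):
    # Value-dependent Python that would break symbolic tracing if inlined:
    # dict unpacking plus a loop over runtime keys.
    merged = {**features}
    return [merged[k] for k in key_order if k in merged]

class Wrapper(torch.nn.Module):
    def forward(self, features, key_order):
        return dynamic_select(features, key_order)  # opaque to the tracer

graph_module = torch.fx.symbolic_trace(Wrapper())
print(graph_module.graph)  # shows a single call_function[dynamic_select] node
```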
```python
# Native (torchrec/distributed/types.py:1245): both waits chained
def forward(self, *input, **kwargs):
    ctx = self.create_context()
    dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()  # ← wait #1 . wait #2
    return self.compute_and_output_dist(ctx, dist_input)

# Pipelined: wait #1 happens in wait_sparse_data_dist
def wait_sparse_data_dist(self, context):
    for names, awaitable in context.fused_splits_awaitables:
        for name, request in zip(names, awaitable.wait()):  # ← wait #1
            context.input_dist_tensors_requests[name] = request

# ... and wait #2 happens inside PipelinedForward.__call__
def __call__(self, *input, **kwargs):
    request = self._context.input_dist_tensors_requests.pop(self._name)
    ctx = self._context.module_contexts.pop(self._name)  # sharding ctx captured at dist-start time
    with torch.cuda.stream(self._stream):
        data = request.wait()  # ← wait #2
    return self._module.compute_and_output_dist(ctx, data)
```
wait #1 runs on data_dist_stream, wait #2 on default_stream — both NCCL latencies are covered by neighbour batches' fwd/bwd. Per-variant wait placement is already annotated in the §2 class diagram.
| | Stage 1 (splits) | Stage 2 (tensors) |
|---|---|---|
| Payload per call | a few int64 per feature (~tens of bytes) | real KJT data (MB to hundreds of MB) |
| Launch-overhead share | dominant (~20 μs ≫ transfer time) | negligible |
| Fuse ROI | high: N launches → 1 | low; also serialises per-module pipelining |
| Is fused? | ✓ `_fuse_input_dist_splits` | ✗ left per-module async |
input.variable_stride_per_key() decides which path is taken. KJTAllToAllSplitsAwaitable.__init__ short-circuits; the override does not, but FusedKJTListSplitsAwaitable's pg is None branch skips the NCCL call. _fuse_input_dist_splits groups by pg, producing one fused awaitable per PG; there is no fusing across PGs.
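A hand-written sketch of that grouping, not the TorchRec source; `FusedSplitsAwaitableSketch` stands in for FusedKJTListSplitsAwaitable, whose real implementation also collapses the N small exchanges into one per PG:

```python
from collections import defaultdict

class FusedSplitsAwaitableSketch:
    """Stand-in for FusedKJTListSplitsAwaitable: one wait for a whole PG's splits
    requests; the real class skips the NCCL exchange entirely when pg is None."""
    def __init__(self, requests, pg):
        self._requests, self._pg = requests, pg

    def wait(self):
        # Real code issues the fused exchange for this PG here (or nothing when
        # self._pg is None); this sketch simply drains the underlying requests.
        return [r.wait() for r in self._requests]

def fuse_input_dist_splits_sketch(context):
    # Group per-module splits requests by process group: one fused awaitable
    # per PG, and never any fusing across PGs.
    by_pg = defaultdict(lambda: ([], []))
    for name, request in context.input_dist_splits_requests.items():
        names, requests = by_pg[getattr(request, "pg", None)]
        names.append(name)
        requests.append(request)
    context.fused_splits_awaitables = [
        (names, FusedSplitsAwaitableSketch(requests, pg))
        for pg, (names, requests) in by_pg.items()
    ]
```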