Forward Hijack — How TorchRec Pipeline Splits forward()

NVIDIA recsys-examples · _rewrite_model · BaseForward

§1 Why Split forward()?

ShardedModule.forward() runs two AllToAll stages back-to-back — the pipeline's whole job is to split them apart.

TorchRec's default ShardedModule.forward() hides two NCCL collectives:

  1. input_dist — AllToAll the local KJT per the sharding plan, so each rank receives the keys it owns
  2. compute_and_output_dist — local embedding lookup + another AllToAll to ship results back

Natively both run serially: the main thread blocks on input_dist's NCCL before any downstream GPU work can start, so milliseconds of AllToAll latency land directly on the critical path.

Core idea: pull input_dist out of batch_i's forward and run it during batch_(i-1)'s fwd/bwd on an independent CUDA stream. The AllToAll is then fully hidden behind GPU compute.
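A minimal sketch of that overlap pattern (not the library code — start_input_dist stands in for _start_data_dist, and loss_fn/optimizer are whatever your training loop uses):

# Minimal sketch of the overlap: launch the NEXT batch's input_dist on a side
# stream while the default stream runs the CURRENT batch's fwd/bwd.
import torch

data_dist_stream = torch.cuda.Stream()

def train_step(model, optimizer, loss_fn, cur_batch, next_batch, start_input_dist):
    # Side stream: kick off batch_(i+1)'s KJT AllToAll now; its awaitable is
    # parked in the TrainPipelineContext and resolved later (wait #1 / #2).
    with torch.cuda.stream(data_dist_stream):
        start_input_dist(model, next_batch)

    # Default stream: the hijacked forward for batch_i skips input_dist and
    # reads its pre-distributed KJT from the context instead.
    loss = loss_fn(model(cur_batch), cur_batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()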

[Figure: Native vs Pipelined forward — block diagram. Native: ShardedModule.forward(kjt) runs ① input_dist(ctx, kjt) (AllToAll, network-bound) then ② compute_and_output_dist(ctx, data) (embedding lookup, GPU-bound) serially, the main thread blocking on every NCCL with zero overlap. Pipelined: the forward is hijacked into two tasks on different streams — ① _start_data_dist(batch_i+1) runs ahead on data_dist_stream during batch_i's fwd/bwd, ② PipelinedForward.__call__(batch_i) reads the TrainPipelineContext on the default stream and skips input_dist.]
⚠️ Side effects you'll meet
  • User code like self.shared_ebc(batch.feats_sparse) still evaluates the argument expression, but the result is thrown away — the real input comes from the context.
  • With the hijacked forward, plain model.eval() without restoring the original forward is unsafe — the patched forward always expects a populated context (see the sketch below).
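One defensive pattern — names here are placeholders, not a recsys-examples API — is to snapshot the original bound forwards before the first rewrite and put them back around eval:

# Hedged sketch: snapshot the original forwards BEFORE the pipeline's first
# fill/rewrite, then restore them around eval. `sharded_modules` is a
# placeholder for however you enumerate the patched modules.
import torch

original_forwards = {name: m.forward for name, m in sharded_modules.items()}

def run_eval(model, eval_batches):
    for name, m in sharded_modules.items():
        m.forward = original_forwards[name]   # undo the hijack: input_dist runs inline again
    model.eval()
    with torch.no_grad():
        for batch in eval_batches:
            model(batch)
    model.train()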

§2 BaseForward Class Hierarchy

3 pipeline variants · 3 forward-replacement classes · 3 different wait-timing strategies.

[Figure: BaseForward inheritance — class diagram. BaseForward[TForwardContext] (name · args · module · context · stream; abstract __call__(*input) → Awaitable; torchrec-generic, no pipeline knowledge) is extended by: PipelinedForward (pops input_dist_tensors_requests, ★ wait #2 inside __call__, then compute_and_output_dist; used by TrainPipelineSparseDist), PrefetchPipelinedForward (pops module_input_post_prefetch, zero .wait() in __call__, then compute_and_output_dist; used by PrefetchTrainPipelineSparseDist), and SplitPrefetchPipelined (pops self._cached_awaitable, ★ split across two tasks, two-phase via _IdempotentAwaitable; used by SWSerialTrainPipeline). ★ wait #1 (splits AllToAll) always happens in wait_sparse_data_dist() — see §4. wait #2 (tensors AllToAll) is where the 3 variants diverge: inline vs. moved to prefetch vs. cached awaitable.]

Each class box's three metadata lines — pop source, whether .wait() stays inside __call__, and which Pipeline uses it — map onto the three wait-timing strategies in §4.
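A condensed sketch of the base class's shape (attribute names follow the class diagram; bodies trimmed and signatures approximate):

# Condensed sketch of the hierarchy — not the verbatim torchrec source.
from typing import Generic, List, Optional, TypeVar
import torch

TCtx = TypeVar("TCtx")

class BaseForward(Generic[TCtx]):
    """Generic callable that replaces ShardedModule.forward."""
    def __init__(self, name: str, args: List["ArgInfo"], module: torch.nn.Module,
                 context: TCtx, stream: Optional[torch.cuda.Stream] = None) -> None:
        self._name, self._args = name, args
        self._module, self._context, self._stream = module, context, stream

    def __call__(self, *input, **kwargs):
        # each subclass decides where (or whether) .wait() happens — see §4
        raise NotImplementedError

class PipelinedForward(BaseForward["TrainPipelineContext"]):
    # pops input_dist_tensors_requests and runs wait #2 inline in __call__
    ...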

§3 The _rewrite_model Split Surgery

5-step FX trace + monkey patch; TrainPipelineContext is the blackboard between the two halves.

[Figure: _rewrite_model 5-step flow + context bus. A · One-time rewrite (inside fill_pipeline): Step 1 unwrap DMP/DDP → Step 2 collect ShardedModules → Step 3 FX trace (ShardedModules as leaves) → Step 4 build ArgInfo list → Step 5 swap .forward, so module.forward is now a PipelinedForward. B · Each iteration the TrainPipelineContext acts as the shared blackboard (input_dist_splits_requests · fused_splits_awaitables · input_dist_tensors_requests · module_contexts): _start_data_dist (on data_dist_stream) writes splits requests, wait_sparse_data_dist (★ wait #1, splits AllToAll) reads splits and writes tensors requests, PipelinedForward.__call__ (★ wait #2, on default_stream) reads tensors requests. The bus carries awaitables across pipeline stages.]
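A heavily condensed sketch of the five steps (sharded_module_cls, pipelined_forward_cls, and build_arg_info_list are stand-ins passed in for illustration; the real logic is torchrec's _rewrite_model):

# Heavily condensed sketch — not the production implementation.
import torch.fx

def rewrite_model_sketch(model, context, data_dist_stream,
                         sharded_module_cls, pipelined_forward_cls,
                         build_arg_info_list):
    # 1. unwrap DistributedModelParallel / DDP down to the inner module
    while hasattr(model, "module"):
        model = model.module

    # 2. collect every ShardedModule, keyed by its qualified name
    sharded = {name: m for name, m in model.named_modules()
               if isinstance(m, sharded_module_cls)}

    # 3. FX-trace the model, treating ShardedModules as leaf nodes
    class LeafTracer(torch.fx.Tracer):
        def is_leaf_module(self, m, qualname):
            return qualname in sharded or super().is_leaf_module(m, qualname)
    graph = LeafTracer().trace(model)

    # 4 + 5. for each call into a ShardedModule, record how its args are
    # derived from the batch (ArgInfo) and swap in the pipelined forward
    for node in graph.nodes:
        if node.op == "call_module" and node.target in sharded:
            args = build_arg_info_list(node)                     # step 4
            module = sharded[node.target]
            module.forward = pipelined_forward_cls(              # step 5
                node.target, args, module, context, data_dist_stream)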

ArgInfo is Step 4's output — a recorded access path from the batch to the sharded module's argument. A stripped example:

# User wrote:  self.shared_ebc(batch.sparse_features)
# FX trace recovers this access path into ArgInfo ↓
ArgInfo(
    input_attrs    = ["", "sparse_features"],  # getattr chain from batch
    is_getitems    = [False, False],
    postproc_modules = [None, None],
    constants      = [None, None],
    name           = None,                     # positional arg
)
# Later, _start_data_dist replays this on batch_i+1:
#     batch_i+1 -> getattr "sparse_features" -> ShardedModule.input_dist(...)
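At dispatch time the recorded path is simply replayed on the next batch. A hedged sketch of that replay (the helper name resolve_arg is hypothetical; the real logic lives inside _start_data_dist's argument building):

# Hedged sketch of the replay step.
def resolve_arg(batch, arg_info):
    value = batch
    for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
        if attr == "":                 # "" means "the batch object itself"
            continue
        value = value[attr] if is_getitem else getattr(value, attr)
    return value
# resolve_arg(batch_next, arg_info) -> batch_next.sparse_features, which is
# then fed to ShardedModule.input_dist(ctx, kjt) on the data_dist stream.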
Context field                   Written by                      Read by
input_dist_splits_requests      _start_data_dist                _fuse_input_dist_splits
fused_splits_awaitables         _fuse_input_dist_splits         wait_sparse_data_dist
input_dist_tensors_requests     wait_sparse_data_dist           PipelinedForward.__call__
module_contexts                 _start_data_dist                PipelinedForward.__call__
module_input_post_prefetch      _prefetch (Prefetch variant)    PrefetchPipelinedForward.__call__
module_contexts_post_prefetch   _prefetch (Prefetch variant)    PrefetchPipelinedForward.__call__
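Concretely, the blackboard is roughly a dataclass of dicts keyed by the FX-qualified module name. A condensed sketch (field names mirror the table above; awaitable types are simplified, and in the real code the prefetch fields live on the Prefetch variant's context subclass):

# Condensed sketch of the blackboard — types simplified.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

@dataclass
class TrainPipelineContext:
    input_dist_splits_requests: Dict[str, Any] = field(default_factory=dict)
    fused_splits_awaitables: List[Tuple[List[str], Any]] = field(default_factory=list)
    input_dist_tensors_requests: Dict[str, Any] = field(default_factory=dict)
    module_contexts: Dict[str, Any] = field(default_factory=dict)
    # Prefetch variant only:
    module_input_post_prefetch: Dict[str, Any] = field(default_factory=dict)
    module_contexts_post_prefetch: Dict[str, Any] = field(default_factory=dict)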
⚠️ Model must be FX-traceable
The FX trace walks through every non-leaf ancestor of a ShardedModule. Any dynamic Python along the way — {**proxy}, len(proxy), value-dependent if/for — breaks the trace. Either rewrite that code to be FX-friendly, or wrap it with @torch.fx.wrap to hide it behind an opaque call_function node.
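For example, a value-dependent helper on the path to a ShardedModule can be hidden behind an opaque node like this (illustrative function, not from the repo):

# Hide a dynamic helper from the trace with @torch.fx.wrap.
import torch
import torch.fx

@torch.fx.wrap
def pick_features(batch_dict, keys):
    # dict iteration / membership tests on a Proxy would break tracing if
    # inlined; wrapped, this becomes a single opaque call_function node.
    return {k: batch_dict[k] for k in keys if k in batch_dict}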

§4 Two .wait()s: Native vs Pipelined

ShardedModule.input_dist() returns Awaitable[Awaitable[X]] — the double layer is exactly the seam the pipeline uses to split the work.

[Figure: where wait #1 and wait #2 live — side-by-side sequence diagram. Native: model(batch) calls input_dist(ctx, kjt), which returns Awaitable[Awaitable[X]]; both waits (★ #1 splits, ★ #2 tensors) run serially on the main thread ("self.input_dist(...).wait().wait()" — types.py:1245), so the AllToAll latency is entirely exposed with no chance to overlap the prior batch's fwd/bwd. Pipelined: _start_data_dist pushes the splits request into the context, wait_sparse_data_dist pops it and runs ★ wait #1 (splits, train_pipeline.py:435) then pushes the tensors request, PipelinedForward pops it, runs ★ wait #2 (tensors, utils.py:519), and calls compute_and_output_dist. wait #1 runs on data_dist_stream and wait #2 on default_stream, so both NCCL latencies hide behind neighbouring batches' fwd/bwd.]
# Native (torchrec/distributed/types.py:1245) — both waits chained:
def forward(self, *input, **kwargs):
    ctx = self.create_context()
    dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()  # ← wait #1 . wait #2
    return self.compute_and_output_dist(ctx, dist_input)

# Pipelined — wait #1 happens in wait_sparse_data_dist:
def wait_sparse_data_dist(self, context):
    for names, awaitable in context.fused_splits_awaitables:
        for name, request in zip(names, awaitable.wait()):        # ← wait #1
            context.input_dist_tensors_requests[name] = request

# ... and wait #2 happens inside PipelinedForward.__call__:
def __call__(self, *input, **kwargs):
    request = self._context.input_dist_tensors_requests.pop(self._name)
    with torch.cuda.stream(self._stream):
        data = request.wait()                                            # ← wait #2
    ctx = self._context.module_contexts.pop(self._name)                 # ctx stashed earlier by _start_data_dist
    return self._module.compute_and_output_dist(ctx, data)
[Figure: 4 streams × 3 iterations timeline — Gantt swimlanes for the memcpy stream (H2D copy), data_dist stream (input_dist NCCL), prefetch stream (dynamic cache warmup), and default stream (fwd · bwd · optim). While the default stream runs iteration i's compute, the memcpy/data_dist/prefetch streams already pre-stage iteration i+1: H2D copy → fused AllToAll (★ wait #1) → prefetch, with ★ wait #2 landing just before model(batch_i). Both waits sit on dedicated streams; neither blocks the default-stream critical path.]
💡 Each wait rides its own stream
wait #1 runs on data_dist_stream and wait #2 on the default stream — both NCCL latencies are covered by the neighbouring batches' fwd/bwd. The per-variant differences in wait placement are annotated in the §2 class diagram.

§5 Fusing N splits AllToAlls into one

The diagram below shows why Stage 1 (splits) — not Stage 2 (tensors) — is the fusing target.

[Figure: N independent launches vs one fused launch — sequence diagram. Native: module_1 … module_N each issue their own all_to_all_single(splits_i) to NCCL, paying N × ~20 μs launch overhead. Fused: each module's overridden input_dist returns a KJTSplitsAllToAllMeta (no NCCL); _fuse_input_dist_splits stacks and flattens the metadata and issues ONE all_to_all_single per process group — 1 × ~20 μs.]
                        Stage 1 (splits)                            Stage 2 (tensors)
Payload per call        a few int64 per feature (~tens of bytes)    real KJT data (MB – hundreds of MB)
Launch-overhead share   dominant (~20 μs ≫ transfer time)           negligible
Fuse ROI                high — N launches → 1                       low; also serialises per-module pipelining
Is fused?               yes — _fuse_input_dist_splits               no — left per-module async
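The fusion itself boils down to concatenating everyone's tiny int64 splits and issuing a single all_to_all_single. A hedged sketch (tensor layout heavily simplified compared to KJTSplitsAllToAllMeta / _fuse_input_dist_splits):

# Hedged sketch of the fusion idea — one NCCL launch instead of N.
import torch
import torch.distributed as dist

def fused_splits_all_to_all(per_module_splits, pg):
    """per_module_splits: one int64 tensor of shape [world_size, k_m] per
    ShardedModule — row r is the splits metadata destined for rank r."""
    if pg is None:                                  # single worker: nothing to exchange
        return per_module_splits
    widths = [t.shape[1] for t in per_module_splits]
    fused_in = torch.cat(per_module_splits, dim=1).flatten()   # one contiguous buffer
    fused_out = torch.empty_like(fused_in)
    # ONE launch: contiguous chunk r of fused_in goes to rank r
    dist.all_to_all_single(fused_out, fused_in, group=pg)
    world_size = dist.get_world_size(pg)
    return list(fused_out.view(world_size, -1).split(widths, dim=1))  # back to per-module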
💡 Edge cases
  • variable_stride_per_key=True (per-key variable batch): the override branches on input.variable_stride_per_key() to pick the right path.
  • Single worker (world_size=1): the native KJTAllToAllSplitsAwaitable.__init__ short-circuits; the override does not, but FusedKJTListSplitsAwaitable skips NCCL when pg is None.
  • Multiple PGs (hybrid sharding): _fuse_input_dist_splits groups by pg and builds one fused awaitable per PG — no fusing across PGs.