Multi-thread, multi-stream execution with data dependency & NCCL ordering guarantees
A task's stream decides which CUDA stream context it runs in.
The executor's thread_map decides which CPU worker thread submits it. The two are independent.
ThreadedExecutor(thread_map="by_stream") # default: group by task.stream
ThreadedExecutor(thread_map="per_task") # every task gets its own thread
ThreadedExecutor(thread_map={"h2d": "io", "fwd": "compute", ...}) # explicit dict
ThreadedExecutor(thread_map=lambda t: "io" if t.stream=="memcpy" else "compute") # callable
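A minimal sketch of how these four forms could be normalized into a single "task → thread name" resolver (Task and resolve_thread_map here are illustrative stand-ins, not the actual API):

```python
from typing import Callable, Mapping, Union

class Task:
    """Hypothetical stand-in: only .name and .stream matter for thread selection."""
    def __init__(self, name: str, stream: str):
        self.name = name
        self.stream = stream

ThreadMap = Union[str, Mapping[str, str], Callable[[Task], str]]

def resolve_thread_map(thread_map: ThreadMap) -> Callable[[Task], str]:
    """Normalize the four accepted forms into a task -> worker-thread-name function."""
    if thread_map == "by_stream":        # default: one worker thread per CUDA stream
        return lambda t: t.stream
    if thread_map == "per_task":         # every task gets its own thread
        return lambda t: t.name
    if callable(thread_map):             # user-supplied callable
        return thread_map
    return lambda t: thread_map[t.name]  # explicit dict keyed by task name

picker = resolve_thread_map(lambda t: "io" if t.stream == "memcpy" else "compute")
assert picker(Task("h2d", "memcpy")) == "io"
assert picker(Task("fwd", "default")) == "compute"
```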
wait_stream     — from cross_stream_waits (stream-based)
threading.Event — from DAG analysis (thread-based)
       Schedule (tasks with reads/writes/depends_on)
                             │
            ┌────────────────┴─────────────────┐
            ▼                                  ▼
infer_cross_stream_waits()            _compute_cpu_deps()
(deps.py — stream-based)              (executor.py — thread-based)
            │                                  │
            ▼                                  ▼
GPU: consumer_stream.wait_stream(     CPU: threading.Event.wait()
       producer_stream)                    before task.run() on
                                           different thread than
                                           the dependency
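Roughly, the two mechanisms are consumed like this at run time (a sketch, assuming PyTorch streams; the argument names and task fields are illustrative):

```python
import threading
import torch

def run_task(task, streams, events, cpu_deps, cross_stream_waits):
    """Illustrative worker-side handling of one task's dependencies.

    streams:            stream name -> torch.cuda.Stream
    events:             task name   -> threading.Event, set when that task finishes
    cpu_deps:           task name   -> deps scheduled on a *different* thread
                        (the shape of what _compute_cpu_deps would produce)
    cross_stream_waits: task name   -> producer stream names
                        (the shape of what infer_cross_stream_waits would produce)
    """
    # CPU ordering: block this worker thread on dependencies owned by other threads.
    for dep in cpu_deps.get(task.name, ()):
        events[dep].wait()

    consumer = streams[task.stream]
    # GPU ordering: the consumer stream waits on each producer stream; this enqueues
    # a wait on the device and does not block the CPU.
    for producer in cross_stream_waits.get(task.name, ()):
        consumer.wait_stream(streams[producer])

    with torch.cuda.stream(consumer):
        task.run()

    events[task.name].set()   # unblock CPU-side consumers on other threads
```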
Declaration order:  [task_A (nccl),  task_B (nccl),  task_C (nccl)]
Ticket assignment:   ticket=0        ticket=1        ticket=2
_NcclOrderedLock state machine:
┌──────────────────────────────────────────────────────┐
│ next_ticket = 0                                      │
│                                                      │
│ Thread X: acquire(ticket=0) → runs immediately       │
│           task_A.run()                               │
│           release() → next_ticket = 1, notify_all()  │
│                                                      │
│ Thread Y: acquire(ticket=1) → was waiting, now runs  │
│           task_B.run()                               │
│           release() → next_ticket = 2, notify_all()  │
│                                                      │
│ Thread X: acquire(ticket=2) → runs                   │
│           task_C.run()                               │
│           release() → next_ticket = 3                │
└──────────────────────────────────────────────────────┘
If task_B fails:
    release(failed=True) → self._failed = True
    Thread X: acquire(ticket=2) → sees _failed → raises RuntimeError
              → task_C never runs (no desync across ranks)
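The ticket discipline above can be sketched as a small condition-variable lock; this is illustrative, assuming the real _NcclOrderedLock keeps equivalent next_ticket/_failed state:

```python
import threading

class TicketLock:
    """Run same-group NCCL tasks strictly in declaration (ticket) order across threads."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._next_ticket = 0
        self._failed = False

    def acquire(self, ticket: int) -> None:
        with self._cond:
            # Block until it is this ticket's turn, or until a predecessor has failed.
            self._cond.wait_for(lambda: self._failed or self._next_ticket == ticket)
            if self._failed:
                raise RuntimeError(
                    "an earlier NCCL task failed; not running this one (avoids rank desync)"
                )

    def release(self, failed: bool = False) -> None:
        with self._cond:
            if failed:
                self._failed = True   # poison later tickets so they raise instead of running
            self._next_ticket += 1
            self._cond.notify_all()   # wake the thread holding the next ticket
```

A worker holding ticket k calls acquire(k), issues its collective, then release(); on exception it calls release(failed=True) so the holder of ticket k+1 raises instead of launching a collective that other ranks will never match.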
batch_offset decides which batch a task operates on. Within one internal iteration,
tasks with different offsets operate on different batches in the BatchRing simultaneously.
A task with batch_offset=k runs at internal iteration i iff:
    (max_offset - k) ≤ i < M + (max_offset - k)
where M is the number of batches pulled from the dataloader.
For in_flight_batches=2 (max_offset=1), M=5:
k=1 (h2d): runs at i=0,1,2,3,4 (prefetch: one iteration ahead)
k=0 (compute): runs at i=1,2,3,4,5 (steady + drain)
i=0: prefill — only h2d(batch_0) runs
i=1: steady — h2d(batch_1) + compute(batch_0) overlap!
i=2: steady — h2d(batch_2) + compute(batch_1) overlap!
...
i=5: drain — only compute(batch_4) runs
With the ThreadedExecutor, these overlapping tasks can run on different CPU threads too.
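The window rule and the prefill/steady/drain phases can be reproduced directly; runs_at below is an illustrative helper, not part of the executor:

```python
def runs_at(i: int, k: int, max_offset: int, M: int) -> bool:
    """True iff a task with batch_offset=k runs at internal iteration i."""
    return (max_offset - k) <= i < M + (max_offset - k)

# in_flight_batches=2 -> max_offset=1; M=5 batches pulled from the dataloader.
max_offset, M = 1, 5
offsets = {"h2d": 1, "compute": 0}          # task name -> batch_offset k

for i in range(M + max_offset):             # i = 0 .. 5
    active = [f"{name}(batch_{i - (max_offset - k)})"
              for name, k in offsets.items() if runs_at(i, k, max_offset, M)]
    print(f"i={i}: " + " + ".join(active))
# i=0: h2d(batch_0)                      <- prefill
# i=1: h2d(batch_1) + compute(batch_0)   <- steady
# ...
# i=5: compute(batch_4)                  <- drain
```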
Each candidate fixes a choice of thread_map, stream_map, and fire_ordering, then calls the existing auto_assign_lookaheads lookahead oracle before scoring.
max_threads <= 4
max_in_flight <= 5
forward / backward / finalize_model_grads / optimizer_step:
    stream="default", lookahead=0
NCCL tasks with the same comm group serialize even on different streams.
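One plausible way to enforce that (a sketch reusing the TicketLock above; comm_group and the nccl_* attributes are assumed names, not the actual fields):

```python
from collections import defaultdict
from itertools import count

# One ordered lock and one ticket counter per NCCL communicator group, so that
# tasks in the same group run in declaration order even across streams/threads.
group_locks = defaultdict(TicketLock)    # TicketLock: see the sketch above
group_tickets = defaultdict(count)

def assign_nccl_ordering(task) -> None:
    """Called once per NCCL task, in declaration order."""
    group = getattr(task, "comm_group", None)
    if group is not None:
        task.nccl_lock = group_locks[group]
        task.nccl_ticket = next(group_tickets[group])
```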