Updating 1T parameters in seconds — P2P weight transfer in Large Scale Distributed RL
We introduced an RDMA-based, peer-to-peer weight update mechanism for RL workloads in SGLang as a supplement to traditional NCCL broadcast, compatible with all major open source models. By utilizing a source-side CPU engine replica and P2P RDMA transfers via Mooncake TransferEngine, we speed up weight transfer for the 1T-parameter Kimi-K2 by 7x (53 seconds -> 7.2 seconds), at the cost of one additional inference engine replica (32 GB) per training rank in CPU memory. These optimizations eliminate redundant network traffic and allow inference servers to resume rollout significantly faster.
Background
NVIDIA's NCCL optimizes primitives like all-gather and broadcast by auto-detecting hardware topology and coordinating data flow via ring or tree algorithms. As the default communication backend for PyTorch FSDP, DeepSpeed, and Megatron-LM, it is the industry standard for symmetric training. However, it relies on collective semantics, requiring every rank to invoke the same operation simultaneously with matching data shapes. While highly efficient for balanced workloads, this design becomes a liability in dynamic environments: NCCL operates in lock-step, meaning a "slow-start" from a single receiver can hang the entire group and leave resources idle.
RDMA (Remote Direct Memory Access) allows machines to access remote memory while bypassing the remote CPU and kernel networking stack entirely. Its efficiency stems from three core characteristics:
- Kernel Bypass: Applications submit Work Requests directly to the NIC, eliminating expensive system calls and context switching.
- Zero Copy: Data moves directly between registered Memory Regions and the network via DMA, avoiding intermediate copies within the kernel buffer.
- One-Sided Operations: RDMA READ/WRITE operations are initiated by one side, requiring no active CPU participation or interrupt handling on the remote end.
Unlike NCCL's global synchronization, RDMA allows any two endpoints to communicate independently and concurrently, making it an ideal foundation for high-speed weight transfers. This is precisely why the P2P weight update mechanism described here builds on RDMA transfers via the Mooncake TransferEngine.
The RL Weight Transfer Problem: In large-scale distributed RL training, weight transfer from trainers to inference engines is a critical-path operation: during the transfer, the entire RL run comes to a halt, with neither trainer nor inference making progress and resources typically sitting idle. As models grow, this transfer must scale across multiple hosts and racks, all competing for limited bandwidth. The existing NCCL-based workflow in open source solutions such as miles, slime, and verl relies on a broadcast primitive from a single source rank, which quickly becomes the bottleneck during transfer.
Left: the current weight transfer workflow in miles during distributed training/inference RL. On the source side, all nodes participate in an all-gather along the TP and EP dimensions, producing a gathered tensor at the head rank of each PP rank. The head rank joins a distributed update group and broadcasts the full weight to every engine rank through the update_weight_from_distributed API, where each local rank loads its corresponding shard. This process repeats for every PP rank and for every bucket of weight tensors. Right: the P2P update design relies on a source-side engine replica as an intermediary. The first all-gather step of the bucketed weight update is identical to miles, but the weights are then loaded into a local replica of the SGLang engine shard in CPU memory, which sends its weights to SGLang in the correct shape. Each replica's weights can be sent to multiple SGLang ranks, and each target SGLang TP rank must receive from every PP source.
Challenges with Existing NCCL Broadcast
The existing NCCL broadcast solution has the following challenges:
- Redundancy: Identical data is sent multiple times across the network.
- Inactivity: Most trainer ranks remain idle during the transfer while only a few participate in the broadcast.
- Rigidity: Once defined, the NCCL communication group is fixed, so dynamic scaling with newly created engine instances is an involved change.
This comparison evaluates the performance of transferring the 1T FP8 Kimi K2 model (~1TB). Note: the update_weights_from_tensor interface is excluded as it only supports colocated scenarios.
| Strategy | Efficiency | Open Source Support | Dynamic | Training Side Support | System Complexity | Architectural Flexibility |
|---|---|---|---|---|---|---|
| Disk I/O Strategy (update_weights_from_disk) | 🆘 ~Several Minutes | ✅ Yes | ✅ Yes | ✅ Megatron FSDP | 😊 Simple | 😊 Single API |
| NCCL Broadcast (update_weights_from_distributed) | 🥵 ~50 Seconds | ✅ Yes | 🚫 No (Requires NCCL group rebuild) | ✅ Megatron FSDP | 😊 Simple | 😊 Single API |
| Perplexity fabric-lib P2P | ⚡ ~1.2 Seconds | 🚫 No (RDMA lib only) | ✅ Yes | ❓ FSDP2 DTensor only | 🥵 Very Complex | 🥵 Write-only |
| RDMA P2P (Our Implementation) | 😊 ~7 Seconds | ✅ Yes | ✅ Yes | ✅ Megatron FSDP | 😥 Complex | 😀 Multiple APIs |
While there is a trade-off in transmission efficiency compared to Perplexity's approach, our solution offers a significant performance boost over existing SGLang interfaces. Furthermore, we have achieved high architectural flexibility by encapsulating these capabilities into new API interfaces. Refer to miles for run instructions and the full list of supported models.
Design
Our design shifts from a centralized broadcast to a distributed P2P mapping via RDMA, while staying compatible with all existing open source models and any parallelism configuration, and reusing existing interfaces.
- Source-Side Engine Replicas: We create model replicas in the CPU memory of training ranks. This avoids consuming GPU VRAM and sidesteps repeated memory registration and deregistration.
- P2P Mapping Heuristics: We implement a peer-to-peer mapping between trainer ranks and inference ranks. Instead of a few ranks broadcasting everything, every trainer rank participates by sending its specific shard directly to the target.
- Zero-Copy Transfer: Using TransferEngine, memory is registered once at startup, bypassing the expensive serialization of CUDA IPC handles and kernel side copies.
The implementation heavily relies on existing infrastructure and interfaces:
- TransferEngine serves as the underlying transport layer to enable RDMA zero-copy transfer between CPU and GPU on the network.
- Reuse weight registration information through Rfork, the new remote instance weight loading mechanism exposed through SGLang API.
- The standard SGLang API load_weight(huggingface_tensor), which supports all quantization and sharding configurations.
Several new interfaces are needed on the SGLang side:
- Exposing model parallelism for replica creation: PR #20907
- Mapping hugging face tensor with its corresponding SGLang tensor shard: PR #17326
- A post-process weight engine call for GPU-local processing such as post-quantization similar to PR #15245.
These interfaces are merged into the miles-targeted sglang-miles branch. During a weight update, the caller side operates as follows:
Initialization
| Step | Description |
|---|---|
get_remote_instance_transfer_engine_info | Call SGLang API to get weight registration info |
get_parallelism_info | Call SGLang API to get parallelism definition info (tp, ep, etc) |
build_transfer_plan | Construct training -> inference rank mapping relationship |
create_engine_replica | Create CPU engine replica |
During Each Update
| Step | Description |
|---|---|
pause_and_register_engine | Call SGLang API to pause engine, and register replica weights (once) |
update_weight (non-expert and expert) | Bucketed weight update, non-expert then expert weights |
post_process_weights | Call SGLang API to post process loaded weights, like quantization |
update_weight_version | Call SGLang API to update weight version |
continue_generation | Call SGLang API to resume operation |
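The per-update sequence above can be sketched as a thin orchestration loop. This is a minimal sketch, assuming a hypothetical `client` wrapper around the SGLang API; the method names mirror the steps in the table, but the wrapper itself is illustrative, not part of the released code.

```python
class P2PWeightUpdater:
    """Sketch of the per-update call sequence (hypothetical `client`
    wrapper; step names follow the table above)."""

    def __init__(self, client):
        self.client = client
        self.registered = False

    def update(self, non_expert_buckets, expert_buckets):
        # Pause generation; replica weights are registered only once,
        # on the first update.
        self.client.pause_and_register_engine(register=not self.registered)
        self.registered = True
        # Bucketed weight update: non-expert weights first, then experts.
        for bucket in list(non_expert_buckets) + list(expert_buckets):
            self.client.update_weight(bucket)
        self.client.post_process_weights()   # e.g. on-device re-quantization
        self.client.update_weight_version()
        self.client.continue_generation()    # resume rollout
```

The only state the caller keeps across updates is the one-time registration flag; everything else is driven by the engine-side API calls.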
The result is a general-purpose weight update design that handles any model and all common quantization logic, while achieving fast, zero-copy RDMA transfer with no redundancy and higher bandwidth utilization. Consider a scenario with M source ranks for training and N target ranks for SGLang inference; the source side has a pipeline parallel size of pp, the target side has an expert parallel size of ep, and each engine rank holds P parameters. We also allocate a memory buffer of size K for the bucketed all-gather. If we assume the model contains only expert weights:
| Strategy | #Participating Source Ranks | #Params received per inference rank | #Additional buffer allocated on source | #Additional buffer allocated on target |
|---|---|---|---|---|
| NCCL Broadcast | pp | ep * P | K | K |
| RDMA P2P | M | P | K* + P | 0 |
Table: An illustration of how the RDMA P2P design trades memory allocation for less network transmission and higher utilization. All source ranks participate, versus only the head rank of each pipeline parallel group with NCCL; only the necessary tensors are sent across the network, versus NCCL sending the full all-gathered tensor to each rank. RDMA P2P trades this for an additional memory allocation of P in source CPU memory, while no longer needing any allocation on the receiving side. Note that K* is in practice slightly larger than K, because some hugging face tensors must be cached locally before the full SGLang tensor is updated, as there often exists a many-to-one tensor mapping (q_proj, k_proj, v_proj → qkv_proj in SGLang).
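To make the table concrete, here is a toy worked example under the same experts-only assumption. The numbers are arbitrary illustrative units, not measurements from the paper's setup.

```python
# Toy units, illustrative only (experts-only model, as assumed above).
P, ep, pp, M, K = 10, 8, 4, 32, 4

# NCCL broadcast: only the head rank of each pp group sends, and every
# inference rank receives the full ep-way all-gathered tensor.
nccl_sources = pp
nccl_recv_per_rank = ep * P        # 80 units received per inference rank

# RDMA P2P: every training rank sends, and each inference rank receives
# only its own shard; the source pays an extra CPU replica of size P.
p2p_sources = M
p2p_recv_per_rank = P              # 10 units: 8x less traffic per rank here
p2p_extra_src_mem = K + P          # all-gather buffer plus CPU replica
p2p_extra_tgt_mem = 0              # weights land in place on the target
```

The ratio of received bytes per inference rank (ep here) is exactly the expert-parallel degree, which is why the win grows with EP on the rollout side.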
Implementation Results
We profile the transfer speed of common open source models on H100 8-GPU hosts with InfiniBand interconnect. The time is measured from when the engine pause call returns until the continue_generation call.
| Model Family | Model Name | Total Param | Train Config | Inference Config | NCCL (ms) | RDMA (ms) | Speedup |
|---|---|---|---|---|---|---|---|
| GLM4 | GLM-Z1-9B-0414 | 9B | TP=2, PP=1, CP=2, EP=1, ETP=1, 1 node | TP=4, EP=1, 1 node | 694.6 | 707.1 | 0.98x |
| DeepSeek-V2 ★ | Moonlight-16B-A3B | 16B(3B) | TP=2, PP=1, CP=1, EP=8, ETP=1, 1 node | TP=8, EP=8, 1 node | 1,482.0 | 1,073.3 | 1.38x |
| GLM4-MoE ★ | GLM-4.7-9B-Flash | 30B(3B) | TP=4, PP=1, CP=1, EP=8, ETP=1, 1 node | TP=4, EP=4, 1 node | 2,508.6 | 4,229.0 | 0.59x |
| Qwen3-MoE ★ | Qwen3-30B-A3B | 30B(3B) | TP=4, PP=1, CP=1, EP=8, ETP=1, 2 nodes | TP=8, EP=8, 2 nodes | 2,670.0 | 2,160.2 | 1.24x |
| GLM4-MoE ★ | GLM-4.5-Air | 106B(12B) | TP=1, PP=4, CP=1, EP=8, ETP=1, 4 nodes | TP=8, EP=8, 4 nodes | 5,001.1 | 2,637.2 | 1.90x |
| Qwen3-MoE ★ | Qwen3-235B-A22B | 235B(22B) | TP=4, PP=4, CP=2, EP=16, ETP=1, 8 nodes | TP=32, EP=32, 8 nodes | 10,753.6 | 3,162.0 | 3.40x |
| DeepSeek-V3p2 ★ | GLM-5 | 744B(40B) | TP=4, PP=8, CP=2, EP=16, ETP=1, 16 nodes | TP=64, EP=64, 16 nodes | 58,301.5 | 8,479.7 | 6.88x |
| DeepSeek-V3 ★ | Kimi-K2-fp8 (64-block-quantized) | 1T(64B) | TP=8, PP=8, CP=4, EP=32, ETP=1, 32 nodes | TP=32, EP=32, 32 nodes | 53,279.1 | 7,227.3 | 7.37x |
NOTE (Kimi-K2 special handling): we adjusted Kimi-K2 to use a [64, 64] fp8 block-quant size to fit our profiling configuration.
The performance gains are most visible in large MoE (Mixture-of-Experts) architectures with high expert parallelism on the rollout side. At low node counts, as in the GLM4-MoE example above, when EP is small the cost of locally loading tensors into the CPU model outweighs the benefit of P2P transfer; P2P transfer scales well as more nodes are involved.
Usage
In miles, enable P2P updates via --update-weight-transfer-mode p2p. This makes SGLang engines register their weight memory via --sglang-remote-instance-weight-loader-start-seed-via-transfer-engine and selects the P2P update flow over NCCL broadcast. Miles depends on the sglang-miles branch of SGLang, which carries more advanced experimental features supporting P2P transfer.
Future Plans
- Extending support: Offer official support for newer hardware like GB200, and SGLang side pipeline parallel. Support more quantizations. Merge SGLang side changes to main.
- Experiment Huge Page Allocation: Instead of permanently allocating CPU memory, consider enabling Transfer Engine huge page GPU allocation that can drastically improve registration and deregistration cost. This can enable in-place GPU replica creation and memory registration at transfer time.
Engineering Appendix
Design Iterations
Our initial design placed the source-side replica on the GPU. To avoid wasting GPU memory during training, significant effort went into optimization there: how to register, transfer, and deregister in a pipelined fashion? Could we offload the model to the CPU while keeping the NIC memory registration intact via virtual memory? Surprisingly, given the massive bandwidth available on modern clusters, the transfer itself is the least time-consuming part; with RDMA, weight registration is the biggest time sink, taking tens of seconds for the entire replica. Moving to the CPU resolved it.
Another blocker was that the CPU could initially go OOM. Multiple GPU ranks on a host share the same CPU memory, and creating a replica for every target rank on every source rank quickly becomes unmanageable. But we soon realized that all SGLang ranks are built homogeneously, meaning we could sacrifice a little transfer concurrency to prioritize memory reuse. Our final design reuses the same underlying physical memory and carefully orchestrates transfers to the different engine shards one by one.
Peer to Peer Transfer Plan
The transfer plan maps each training rank to its target rollout engine rank(s), using round-robin assignment with load balancing: the first ranks get a 1:1 mapping, and the remaining targets are distributed evenly. This minimizes the number of RDMA sessions per source.
Imagine training with pp=4, 32 training ranks, and 2 SGLang engine instances of 16 ranks each. After the all-gather, every rank in each pp group contains the fully all-gathered tensor for its specific pp rank, and every target rank needs to receive weights from every pp rank. Start with pp_rank=0, where we need to map 8 training ranks to all 32 target ranks.
- Round-robin: map src_rank 0 -> tgt_rank 0, …, src_rank 7 -> tgt_rank 7.
- All existing sources now carry the same load, so run another round of round-robin: src_rank 0 -> tgt_rank 8, …, src_rank 7 -> tgt_rank 15.
- Finally, note that tgt_rank 16 is identical to tgt_rank 0. Look back at the existing assignments and attach each identical engine rank to its existing source. We end up with src_rank 0 -> [tgt_rank [0, 16], tgt_rank [8, 24]], etc.
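The mapping heuristic above can be sketched in a few lines. This is a simplified reconstruction for a single pp rank, not the actual miles implementation; `build_transfer_plan` and its argument names are hypothetical.

```python
def build_transfer_plan(num_sources, num_targets, engine_size):
    """Map source (training) ranks to groups of target engine ranks.

    Targets with the same rank inside their engine instance hold identical
    shards, so they are grouped and served by one source; the groups are
    then spread round-robin across the sources."""
    # Group targets by their rank within the engine instance.
    groups = {}
    for tgt in range(num_targets):
        groups.setdefault(tgt % engine_size, []).append(tgt)
    # Round-robin the unique engine ranks across the available sources.
    plan = {src: [] for src in range(num_sources)}
    for i, engine_rank in enumerate(sorted(groups)):
        plan[i % num_sources].append(groups[engine_rank])
    return plan

# The worked example from the text: 8 sources, 2 engines of 16 ranks each.
plan = build_transfer_plan(num_sources=8, num_targets=32, engine_size=16)
# src_rank 0 serves the target groups [0, 16] and [8, 24]
```

Each inner list is a group of mutually identical engine ranks, so a source reshards its replica once per group and sends the same bytes to every member.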
Identify SGLang Tensors to Transfer
To execute the weight transfer itself in a bucketed fashion, we need to identify the parameters that are ready to be sent across after each model.load_weight() call. We first construct a mapping between each hugging face tensor and its corresponding SGLang tensor or tensor shard.
sglang_name, shard_id, num_shards, expert_id, num_local_experts = parameter_mapper.map(hf_tensor)
For each model and each hugging face tensor, this yields the expert and tensor shard it maps to within SGLang. For example:
model.layers.0.mlp.experts.3.down_proj.weight -> model.layers.0.mlp.experts.w2_weight, w2, 2, 3, 5
We can only send a SGLang tensor once all num_shards for all num_local_experts have been updated.
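This readiness check can be sketched with a small tracker. `ShardTracker` is a hypothetical helper for illustration, not the actual miles class; it only mirrors the completeness rule stated above.

```python
from collections import defaultdict


class ShardTracker:
    """Marks a fused SGLang tensor as ready only after every
    (shard_id, expert_id) piece mapped to it has been loaded
    (hypothetical helper, not the miles implementation)."""

    def __init__(self):
        self._seen = defaultdict(set)

    def record(self, sglang_name, shard_id, num_shards,
               expert_id, num_local_experts):
        self._seen[sglang_name].add((shard_id, expert_id))
        # Transferable once all shards of all local experts are present.
        return len(self._seen[sglang_name]) == num_shards * num_local_experts
```

For a fused qkv_proj with shards q/k/v (num_shards=3, a single non-expert slot), only the third record call returns True, at which point the tensor can enter the transfer queue.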
Transfer Flow with Shared Replica
To support transferring to multiple target engine instances using the same underlying memory, we leverage a thread-pool-based task pool and a cache buffer for partially updated tensors. A transfer task for a SGLang tensor is either on the critical path and must be waited on, or can be submitted to the pool and checked for completion only at the very end. The TransferEngine library releases the GIL, enabling true parallelism in a multi-threaded configuration.
Figure: Example of transfer flow for one training source rank and 2 target ranks, using the same underlying replica.
Note that the red tensors are post-all-gather and include all experts and TP shards. However, one bucketed update may not contain all tensor shards for a SGLang tensor: as shown in the diagram, q_proj and k_proj were collected but buffered outside the replica during the previous bucketed update. This is also why the shared-replica update needs a buffer size larger than 1. In the worst case, the buffer could be num_shards times the original buffer! However, since named_parameters typically orders related tensors together, the extra buffer actually needed is small.
Once a SGLang tensor has collected all necessary shards, we update the replica to hold the weights for the first target engine rank (SGLang engine rank 0 in the example) using that rank's parallelism info. The transfer (send 1) is on the critical path and must complete (i.e., the target weight is updated) before the underlying CPU shard is updated again from the same hugging face tensor, this time using the parallelism information of another engine rank. For the last engine rank, no one else is waiting on the same SGLang tensors, so the task (send 2) can go to the threadpool and we continue to the next all-gather immediately.
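The critical-path rule can be sketched with a standard thread pool. The helper names here (`replica.load_for_target`, the `send` callable) are hypothetical stand-ins; a real implementation would issue TransferEngine submits, which release the GIL.

```python
from concurrent.futures import ThreadPoolExecutor


def transfer_shared_replica(replica, tensor_name, target_ranks,
                            send, pool, pending):
    """Send one SGLang tensor to every target that shares this CPU replica.

    All but the last send block inline, because the shared replica buffer
    is about to be resharded for the next target; only the final send can
    be deferred to the pool and awaited at the end of the update."""
    for i, tgt in enumerate(target_ranks):
        replica.load_for_target(tensor_name, tgt)  # reshard for this rank
        if i < len(target_ranks) - 1:
            send(tensor_name, tgt)                 # critical path: wait here
        else:
            # Last target: overlap the send with the next all-gather.
            pending.append(pool.submit(send, tensor_name, tgt))
```

The caller drains `pending` once at the very end of the bucketed update, which is where deferred sends are checked for completion.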
Other Quantization and Post Processes
Not all necessary tensors are registered in SGLang's model.named_parameters(). For example, in the DeepSeekV3 model lineage, the MLA contains local tensors w_kc and w_vc, which are generated only after all weights are fully loaded, in a post_weight_load() call. SGLang also contains a large amount of custom quantization and hardware-specific logic that depends on the device and cannot be executed on our CPU replica.