Achieving Faster Open-Source Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM)
by: The SGLang Team, Jul 25, 2024
At LMSYS.org, we've been running the Chatbot Arena platform for over a year, serving millions of users. We know firsthand how crucial efficient serving is for AI products and research. Through our operational experiences and in-depth research, we've continuously enhanced the underlying serving systems, spanning from the high-level multi-model serving framework, FastChat, to the efficient serving engine, SGLang Runtime (SRT).
This post focuses on SGLang Runtime, a general-purpose serving engine for LLMs and VLMs. While existing options like TensorRT-LLM, vLLM, MLC-LLM, and Hugging Face TGI have their merits, we found them sometimes hard to use, difficult to customize, or lacking in performance. This motivated us to develop SGLang v0.2, aiming to create a serving engine that is not only user-friendly and easily modifiable but also delivers top-tier performance. While SGLang includes frontend language features, this post will focus solely on the backend runtime and use "SGLang" and "SGLang Runtime" interchangeably to refer to the runtime.
Compared to TensorRT-LLM and vLLM, SGLang Runtime delivers superior or competitive performance in both online and offline scenarios, across models from Llama-8B to Llama-405B, on A100 and H100 GPUs, and with both BF16 and FP8. SGLang consistently outperforms vLLM, achieving up to 3.1x higher throughput on Llama-70B, and it often matches or sometimes outperforms TensorRT-LLM. More importantly, SGLang is fully open-source, written in pure Python, with the core schedulers implemented in fewer than 4K lines of code.
SGLang is an open-source project licensed under the Apache 2.0 license. It has been used by LMSYS Chatbot Arena to serve some of its models, as well as by Databricks, several startups, and research institutes, generating trillions of tokens and enabling faster iterations. As it gradually matures from a research prototype, we invite the community to join us in creating the next-generation efficient engine.
Benchmark Setup
We benchmark both offline and online use cases:
- Offline: We send 1K to 6K requests at once, measuring output throughput (tokens/second), defined as the number of output tokens divided by the total duration. The tested datasets include several synthetic datasets and the ShareGPT dataset. We use Input-512-Output-1024 to indicate a dataset where the input lengths are sampled from a uniform distribution [1, 512] and the output lengths from [1, 1024].
- Online: We send requests at rates ranging from 1 to 16 requests per second (RPS), measuring the median end-to-end latency. We use the synthetic dataset Input-1024-Output-1024.
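The two metrics above can be made concrete with a small Python sketch. It samples lengths for an Input-512-Output-1024 style dataset, computes offline output throughput, and computes the online median end-to-end latency. The function names are hypothetical, the 120-second duration is a made-up number, and the exponential (Poisson) arrival pattern is an assumption: the post only states the target request rate, not how arrivals are spaced.

```python
import random
import statistics


def sample_lengths(num_requests, max_input, max_output, seed=0):
    """(input_len, output_len) pairs for an Input-X-Output-Y dataset:
    each length is drawn uniformly from [1, max], inclusive."""
    rng = random.Random(seed)
    return [(rng.randint(1, max_input), rng.randint(1, max_output))
            for _ in range(num_requests)]


def output_throughput(output_lens, total_duration_s):
    """Offline metric: total output tokens divided by total duration (tokens/s)."""
    return sum(output_lens) / total_duration_s


def arrival_times(num_requests, rps, seed=0):
    """Online send times (seconds) at an average rate of `rps` requests/second.
    Assumption: exponential inter-arrival gaps, i.e. a Poisson process."""
    rng = random.Random(seed)
    t, times = 0.0, []
    for _ in range(num_requests):
        t += rng.expovariate(rps)
        times.append(t)
    return times


def median_e2e_latency(send_times, finish_times):
    """Online metric: median of (finish - send) over all requests."""
    return statistics.median([f - s for s, f in zip(send_times, finish_times)])


# Input-512-Output-1024 with 1K requests; the 120 s duration is hypothetical.
pairs = sample_lengths(1000, max_input=512, max_output=1024)
print(output_throughput([out for _, out in pairs], total_duration_s=120.0))
print(median_e2e_latency([0.0, 1.0, 2.0], [3.0, 2.5, 6.0]))  # 3.0
```

Using the median rather than the mean keeps a few very slow requests at high request rates from dominating the online metric.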
We use vLLM 0.5.2 with default arguments and TensorRT-LLM v0.10.0 with the recommended arguments and tuned batch sizes. The prefix cache is turned off for all engines, because the goal is to benchmark base performance without additional features such as speculative decoding or caching. We benchmark SGLang and vLLM through their OpenAI-compatible APIs and TensorRT-LLM through its Triton interface.
More details and reproducible scripts are provided in Appendix A. For each model, we will first present the offline results and then the online results.
Update (2024-07-26 4 AM PST): We noticed some issues in our original synthetic data generation pipeline, which primarily generated short inputs, making the dataset description in the first version of this blog post inaccurate. In the current version, we have fixed these issues and introduced more dataset configurations to cover both long and short inputs.
Llama-8B on 1 x A100 (bf16)
Starting with the small model Llama-8B, the figure below shows the maximum output throughput each engine can achieve in offline settings across six different datasets. Both TensorRT-LLM and SGLang can achieve an excellent throughput of up to 5000 tokens per second on a dataset with short inputs, while vLLM lags behind.
The online benchmark figure below shows a trend similar to the offline case. TensorRT-LLM and SGLang perform equally well and can sustain an RPS > 10, while the latency of vLLM increases significantly at a high request rate.
Llama-70B on 8 x A100 (bf16)
Moving to the larger Llama-70B model with tensor parallelism on 8 GPUs, the trend is similar to the 8B case. In the offline benchmark below, both TensorRT-LLM and SGLang can scale to a high throughput.
In the online figure below, TensorRT-LLM shows excellent latency performance thanks to its highly efficient kernel implementations and runtime.
Llama-70B on 8 x H100 (fp8)
Now, let us test FP8 performance. Both vLLM and SGLang use FP8 kernels from CUTLASS. In the offline setting, SGLang’s batch scheduler is very efficient and can keep scaling throughput with larger batch sizes, achieving the highest throughput in this case. Other systems cannot scale their throughput or batch sizes because of out-of-memory (OOM) errors, the need for extensive manual tuning, or other overheads. Generally, SGLang performs better on short inputs, while TensorRT-LLM performs better on long inputs. This is likely due to their different kernel implementations and batch scheduling policies.
The above trend continues in the online case as well, with SGLang and TensorRT-LLM achieving similar median latency.
Llama-405B on 8 x H100 (fp8)
Finally, we benchmark the largest model, Llama-405B. Because the model is so large, most of the time is spent in GPU kernels. The limited KV cache size also leaves less room for scheduling, so the gap between frameworks shrinks. SGLang still outperforms vLLM, but by a smaller margin. Because the 405B model had just been released, some of the latest TensorRT-LLM optimizations were not yet included in the pre-built Docker image, so we omit TensorRT-LLM results here. We are working with the NVIDIA team to correctly benchmark TensorRT-LLM on this model.
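A rough back-of-the-envelope sketch illustrates why the KV cache is tighter for 405B. It assumes the public Llama 3/3.1 model configurations (80 layers for 70B, 126 for 405B, each with 8 KV heads of dimension 128), a 16-bit KV cache, FP8 weights at about one byte per parameter, and 8 x 80 GB of GPU memory. The helper names are hypothetical, and the bounds ignore activations and runtime overheads, so real capacity is lower.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """KV cache bytes for one token: one K and one V vector per layer per KV head."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


def max_cached_tokens(total_gpu_gb, weight_gb, per_token_bytes):
    """Upper bound on cached tokens if all non-weight memory held KV cache.

    Ignores activations, CUDA graphs, and framework overheads.
    """
    return int((total_gpu_gb - weight_gb) * 1e9 // per_token_bytes)


# Assumed configs: 16-bit KV cache, FP8 weights (~1 byte/param), 8 x 80 GB GPUs.
llama_70b = kv_bytes_per_token(num_layers=80, num_kv_heads=8, head_dim=128)
llama_405b = kv_bytes_per_token(num_layers=126, num_kv_heads=8, head_dim=128)
print(llama_70b, llama_405b)  # 327680 516096
print(max_cached_tokens(640, 70, llama_70b))
print(max_cached_tokens(640, 405, llama_405b))
```

Under these assumptions the 405B model has room for roughly 4x fewer cached tokens than the 70B model (about 455K vs. 1.74M), which leaves the scheduler less room to form large batches.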
SGLang Overview
SGLang is a serving framework for large language models and vision-language models. It builds on and enhances many good designs from several open-source LLM serving engines, including LightLLM, vLLM, and Guidance. It leverages high-performance attention CUDA kernels from FlashInfer and integrates torch.compile inspired by gpt-fast.
Additionally, we introduced innovations such as RadixAttention for automatic KV cache reuse and a compressed state machine for fast constrained decoding. SGLang is known for its highly efficient batch scheduler, which is implemented entirely in Python yet scales well, often matching or even outperforming closed-source implementations built with C++. The speedups shown in this blog post come mainly from this careful system engineering.
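To illustrate the idea behind automatic KV cache reuse, here is a toy Python prefix tree: a new request reuses the cached KV entries for its longest matching prefix and only needs to prefill the remaining tokens. This is a simplified sketch, not SGLang's implementation. RadixAttention uses a radix tree with compressed edges, eviction, and real GPU KV memory, whereas this version stores one node per token and a placeholder slot index; the class and method names are hypothetical. Note that the prefix cache is turned off in all benchmarks in this post.

```python
class PrefixCache:
    """Toy token-level prefix tree mapping token prefixes to KV slot indices."""

    def __init__(self):
        self.root = {}       # token -> (kv_slot, child dict)
        self.next_slot = 0   # next free placeholder KV slot

    def match_prefix(self, tokens):
        """Return the KV slots of the longest cached prefix of `tokens`."""
        node, slots = self.root, []
        for tok in tokens:
            if tok not in node:
                break
            slot, node = node[tok]
            slots.append(slot)
        return slots

    def insert(self, tokens):
        """Cache `tokens`, allocating slots only for the uncached suffix.

        Returns how many new slots were allocated, i.e. how many tokens
        would actually need a prefill pass.
        """
        node, new = self.root, 0
        for tok in tokens:
            if tok not in node:
                node[tok] = (self.next_slot, {})
                self.next_slot += 1
                new += 1
            node = node[tok][1]
        return new


cache = PrefixCache()
system = [1, 2, 3, 4]                          # hypothetical shared prompt token IDs
print(cache.insert(system + [10, 11]))         # 6: nothing cached yet
print(len(cache.match_prefix(system + [20])))  # 4: shared prefix is reused
print(cache.insert(system + [20]))             # 1: only the new token is prefilled
```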
The table below compares SGLang, TensorRT-LLM, and vLLM across several aspects. In terms of performance, both SGLang and TensorRT-LLM excel, while vLLM suffers from high CPU scheduling overhead. Regarding usability and customizability, SGLang's lightweight and modular core makes it easy to customize, whereas TensorRT-LLM's complex C++ tech stack and setup instructions make it harder to use and modify. SGLang's source code is fully open-source, while TensorRT-LLM is only partially open-source.
| | SGLang | TensorRT-LLM | vLLM |
|---|---|---|---|
| Performance | Excellent | Excellent | Fair |
| Usability | Good | Poor | Good |
| Customizability | High | Low | Medium |
| Source Code Availability | Fully Open | Partially Open | Fully Open |
| Programming Language | Python | C++ | Python |
What is Next
We're excited to share our latest benchmark results. While there's still more to do, this shows our philosophy of developing a simple, customizable, and high-performance serving engine is achievable. Stay tuned for new features like long context and MoE optimizations, and detailed technical walkthroughs. Join us in building the next-generation serving engine at https://github.com/sgl-project/sglang.
Try Llama Serving
You can serve a Llama model easily with the following steps.
- Install SGLang with pip, from source, or using Docker.
- Launch a server:
```bash
# Llama 8B
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct

# Llama 405B
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8
```
- Send a request with the OpenAI-compatible API:
```bash
curl http://localhost:30000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "default",
    "prompt": "Say this is a test",
    "max_tokens": 7,
    "temperature": 0
  }'
```
- Run the benchmark:
```bash
python3 -m sglang.bench_serving --backend sglang --num-prompts 1000
```
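The completion request from the steps above can also be sent from Python. This sketch uses only the standard library (`urllib.request` and `json`); the function names are hypothetical, and it assumes the server launched above is listening on `localhost:30000`, the address used in the curl example.

```python
import json
import urllib.request

BASE_URL = "http://localhost:30000"  # assumed server address, as in the curl example


def build_completion_request(prompt, max_tokens=7, temperature=0, model="default"):
    """Build a POST request for the OpenAI-compatible /v1/completions endpoint."""
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        f"{BASE_URL}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(prompt, **kwargs):
    """Send the request and return the text of the first returned choice."""
    with urllib.request.urlopen(build_completion_request(prompt, **kwargs)) as resp:
        body = json.load(resp)
    return body["choices"][0]["text"]
```

With a server running, `complete("Say this is a test")` returns the generated text; the response is parsed assuming the standard OpenAI completions format (`choices[0].text`).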
The Team
This blog post is contributed by Liangsheng Yin, Yineng Zhang, Ying Sheng, and over 65 open-source contributors. We thank Databricks for its support; Ying Sheng’s work was done at Databricks. We especially thank Lianmin Zheng, Zihao Ye, and Horace He for their technical support, Matei Zaharia for his helpful advice, and Cody Yu for his feedback.
Appendix A: Detailed Benchmark Setups
The instructions to reproduce the benchmark are at sglang/benchmark/blog_v0_2.
For all benchmarks, we set `ignore_eos` or `min_length/end_id` to ensure each engine outputs the same number of tokens. We tried vLLM 0.5.3.post1, but it often crashed under high load and, in our partial benchmarking, showed similar or worse performance than vLLM 0.5.2. Therefore, we report results from vLLM 0.5.2 instead. While we are aware that server configurations can significantly impact serving performance, we mostly use the default arguments of each engine to mimic the case of a typical user.
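For SGLang, one way to pin the output length is to set `ignore_eos` together with a maximum token count in the sampling parameters, so generation cannot stop early at an end-of-sequence token. The sketch below builds such a payload for SGLang's native `/generate` endpoint; the helper name is hypothetical, and it is an illustration rather than the exact request used by the benchmark scripts.

```python
import json


def fixed_length_generate_payload(prompt, output_len):
    """JSON body for SGLang's native /generate endpoint that requests exactly
    `output_len` new tokens by ignoring the end-of-sequence token."""
    return json.dumps({
        "text": prompt,
        "sampling_params": {
            "max_new_tokens": output_len,  # upper bound on generated tokens
            "ignore_eos": True,            # keep generating past EOS
            "temperature": 0,
        },
    })
```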
For the 8B and 70B models, we use the meta-llama/Meta-Llama-3-8B-Instruct and meta-llama/Meta-Llama-3-70B-Instruct bf16 checkpoints, and the neuralmagic/Meta-Llama-3-70B-Instruct-FP8 fp8 checkpoint. For the 405B model, we use dummy weights for all benchmarks. Since the latest TensorRT-LLM image (r24.06) does not support the fbgemm_fp8 quantization used in the official meta-llama/Meta-Llama-3.1-405B-FP8 checkpoint, we use per-layer fp8 quantization in all frameworks and quantize all layers except lm_head. We believe this provides a fair comparison among all engines. The A100 and H100 GPUs are 80GB SXM versions.
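A minimal sketch of what per-layer FP8 quantization can look like, assuming "per-layer" means one scale per weight tensor and the E4M3 format (largest finite value 448). Each weight is divided by the scale `amax / 448` and rounded to the nearest E4M3 value; dequantization multiplies back by the scale. This pure-Python simulation is for illustration only, and the function names are hypothetical; real engines store FP8 tensors and run FP8 GEMM kernels.

```python
import math

E4M3_MAX = 448.0          # largest finite E4M3 value (assumed format)
E4M3_MIN_NORMAL_EXP = -6  # below 2**-6 values are subnormal, spaced 2**-9 apart


def round_to_e4m3(x):
    """Round to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    exp = max(math.frexp(a)[1] - 1, E4M3_MIN_NORMAL_EXP)  # floor(log2(a)), clamped
    step = 2.0 ** (exp - 3)                               # 3 mantissa bits
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_per_tensor(weights):
    """One scale for the whole tensor; weights ~= quantized * scale."""
    amax = max(abs(w) for w in weights)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(w / scale) for w in weights], scale


q, s = quantize_per_tensor([1.0, -2.0, 4.0])
print(q)  # [112.0, -224.0, 448.0]
```

Following the setup above, such a routine would be applied to every layer's weights except `lm_head`.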