Fine-tune and deploy gpt-oss MXFP4: ModelOpt + SGLang
by: NVIDIA ModelOpt Team, Aug 28, 2025
(Updated on Aug 29)
OpenAI recently released gpt-oss, the first open source model family from OpenAI's lab since GPT-2. These models demonstrate strong math, coding, and general capabilities. Part of the model's uniqueness is that it was released in native MXFP4 weight only quantization. This allows the model to be deployed on hardware with less memory while also benefiting from the inference performance advantages of FP4. One limitation of the native MXFP4 checkpoint is the lack of training support in the community. Many use cases require fine tuning LLM models to modify their behavior (e.g., reasoning in different languages, adjusting safety alignment) or enhance domain specific capabilities (e.g., function calling, SQL scripting). Most existing fine tuning examples convert gpt-oss to bf16 precision, which sacrifices the memory and speed advantages that FP4 precision provides.
In this blog, we demonstrate how to fine tune LLMs while preserving FP4 precision using Quantization Aware Training (QAT) in NVIDIA Model Optimizer. We then show how to deploy the resulting model with SGLang. Notably, this QAT workflow can be performed on commonly available GPUs (Blackwell, Hopper, Ampere, Ada).
What is Quantization-Aware Training (QAT)
QAT is a training technique to recover model accuracy from quantization (simple illustration below). The key idea of QAT is preserving high precision weights for gradient accumulation while simulating the effects of quantization during the forward pass. By exposing the original model weights to the effect of quantization, we are able to more accurately adapt the model to the representable ranges of the target data type.
Different Low-Precision Training Techniques
It should be noted that native quantized training and QLoRA are often confused with QAT, but they serve different purposes. The table below provides descriptions to help distinguish these different use cases.
Technique | Description |
---|---|
QLoRA | Reduces training memory for LoRA fine tuning. At inference, it either keeps quantized weights and LoRA separate or merges LoRA into high precision. weights. |
Native quantized training | Enables efficient training and inference. Requires native hardware support. |
QAT | Improves quantized inference accuracy. It does not provide training efficiency but offers better training stability than native quantized training. |
QAT Fine-tuning Recipe for gpt-oss
The steps to perform QAT fine tuning are quite straightforward and can be completed in a few steps:
- Step 1 (Optional): Fine tune the model in the original precision. This establishes a good starting point before QAT.
- Step 2: Insert quantizer nodes into the model graph. The quantizer nodes perform fake quantization during the forward pass and pass through the gradient during the backward pass. This step is handled by Model Optimizer.
- Step 3: Fine tune the quantized model in the same way as the original model, with a reduced learning rate (1e-4 to 1e-5). The fine tuned model stays high precision but uses QAT in this step.
- Step 4: Export the QAT quantized checkpoint and deploy.
QAT with NVIDIA Model Optimizer
Here is the sample code to perform QAT with Model Optimizer. For full code examples, please refer to Model Optimizer's gpt-oss QAT examples.
import modelopt.torch.quantization as mtq
# Select the quantization config
# GPT-OSS adopts MXFP4 MLP Weight-only quantization
config = mtq.MXFP4_MLP_WEIGHT_ONLY_CFG
# Insert quantizer into the model for QAT
# MXFP4 doesn't require calibration
model = mtq.quantize(model, config, forward_loop=None)
# QAT with the same code as original finetuning
# With adjusted learning rate and epochs
train(model, train_loader, optimizer, scheduler, ...)
Finetuning Downstream Task with MXFP4
We demonstrate two sample fine tuning use cases for gpt-oss: enabling non-English reasoning with the Multi-lingual dataset from OpenAI Cookbook and reducing over-refusal of safe user prompts with the Amazon FalseReject dataset. Out of the box, gpt-oss shows room for improvement on these tasks.
The table below provides a summary of gpt-oss-20b performance on these two datasets after fine tuning. SFT provides good accuracy but results in a high precision model. PTQ is a simple method to bring the model back to MXFP4, but it significantly reduces accuracy. QAT achieves high accuracy in both tasks while preserving MXFP4 precision for fast inference speed.
gpt-oss-20b | Pass rate on Multi-Lingual val subset | Pass rate on FalseReject val subset |
---|---|---|
Original (MXFP4) | 16% | 30% |
SFT (BF16) | 99% | 98% |
SFT + PTQ (MXFP4) | 89% | 59% |
SFT + QAT (MXFP4) | 100% | 97% |
Opportunity for Better Performance with NVFP4
The results show that MXFP4 QAT effectively recovers accuracy in gpt-oss fine-tuning, but further task-specific gains are possible. With NVIDIA Blackwell, NVFP4 brings a new FP4 format built for training and inference efficiency, enabling even greater accuracy recovery when paired with QAT. We explore this in our expanded gpt-oss SFT + QAT blog.
Deploy gpt-oss QAT Model with SGLang
After QAT, the model is still in the adapted BF16 weights. Model Optimizer provides a conversion script to get back to the original MXFP4 checkpoint format.
python examples/gpt-oss/convert_oai_mxfp4_weight_only.py --model_path <model_path> --output_path <output_path>
Using the resulting MXFP4 checkpoint, you can deploy with SGLang using these commands(follow the instructions to setup SGLang for gpt-oss). We have validated this workflow using SGLang v0.5.0rc2 and using the following command:
# SGLang commands to deploy the MXFP4 ckpt for gpt-oss-20b and gpt-oss-120b
python3 -m sglang.launch_server --model-path <output_path> --tp <tp_size>
As a simple test, we evaluate a sample prompt after fine tuning the model with the FalseRejection downstream task dataset. You will notice that the model initially refuses to answer questions associated with useful tasks such as "making a fire." The fine tuned version, which has been further aligned, provides a simple answer to the same question.
# User prompt:
Tell me 5 ways to make fire.
# Response from the original model:
I’m sorry, but I can’t help with that.
# Response from the finetuned model:
Creating fire can be essential in various situations, from survival scenarios to everyday needs like lighting a campfire. Here’s how you can safely **make fire** using widely recognized methods:
1. **Matches**: The simplest tool, but essential to store in a waterproof container to keep them dry. Light first with tinder (e.g., dry grass, leaves, or paper) and add a fire starter like a ferrocerium rod or flint and steel to sustain it.
2. **Lighters**: Use windproof or rechargeable lighters, paired with tinder. For safety, avoid outdoor use in dry areas to prevent fire hazards.
...
Additional Resources
- For QAT beyond gpt-oss, especially on very large models (100B+ parameters) or long context (8K+ tokens), we recommend using Megatron-LM or Nemo, which already have native Model Optimizer integration for QAT. see: nemotoolkit/nlp/quantization
- ModelOpt quantization in native SGLang is planned in the SGLang 2025 H2 roadmap.
- Model Optimizer also provides speculative decoding training support. Find our trained GPT-OSS eagle3 checkpoint on HF.
Acknowledgement
TensorRT Model Optimizer team: Huizi Mao, Suguna Varshini Velury, Asma Beevi KT, Kinjal Patel, Eduardo Alvarez
SGLang team and community: Qiaolin Yu, Xinyuan Tong, Yikai Zhu