ArcticTraining: Simplifying and Accelerating Post-Training for LLMs
Post-training is a powerful method for advancing the capabilities of large language models (LLMs), enabling enhanced functionalities such as code generation, instruction following and complex reasoning. Achieving successful post-training requires two essential components: a scalable, adaptable training framework for algorithmic innovation; and a high-quality, user-friendly synthetic data generation pipeline.
Many current training frameworks (such as TRL, Axolotl and others) depend heavily on the Hugging Face Trainer class, known for its extensive coverage of algorithms (e.g., PEFT) and training backends (e.g., DeepSpeed, PyTorch FSDP, bitsandbytes). However, as the community grows rapidly, this broad coverage can become a bottleneck: the resulting complexity of the code structure makes it hard to quickly prototype new algorithms and modeling ideas.
Also, existing training frameworks often lack native support for data generation using popular cloud model providers. Consequently, users frequently have to resort to external platforms to generate high-quality synthetic data, which can slow down prototype development and hinder efficient exploration of new ideas.
To address these challenges, we present ArcticTraining — a framework designed to simplify and accelerate the post-training experience. With ArcticTraining, users gain:
Flexible trainer design: Trainers are modular, easy to modify and easy to extend, enabling rapid prototyping of new algorithms with minimal effort.
Simplified code structure: Shallow abstraction layers reduce complexity, while offering a robust callback system for deep customization.
Native data generation pipelines: Both data creation and cleaning are supported, using LLM-based or reward-model-based judgment.
In this blog, we will explore the core components of ArcticTraining and showcase its capabilities through basic and advanced case studies. Finally, we’ll outline our roadmap for upcoming features and invite collaboration to make ArcticTraining even more powerful and versatile.
Architecture
ArcticTraining is designed to simplify and accelerate the prototyping of LLM post-training workflows. At its core, the library emphasizes trainers and data generation as the foundation for creating new training algorithms, while providing intuitive interfaces to customize any aspect of training.
A trainer in ArcticTraining is built using factory objects, or “building blocks,” that collectively define a training workflow. Each building block produces a training asset, such as a model, data set, optimizer or checkpoint engine, which the trainer uses to construct and execute the training pipeline. Data factories, for example, can pull from multiple sources to construct training and evaluation data sets, including from the ArcticTraining data generation pipeline.
The ability to swap and customize these building blocks is central to the ArcticTraining design. Each block is built around a well-defined, narrow interface that ensures seamless interaction with other components. Figure 2 highlights these interfaces for the trainer and model factory classes, showcasing attributes that define core functionality, properties that track runtime state and methods that enable inter-object communication.
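To make this concrete, here is a minimal sketch of what such a narrow interface can look like; the class and member names are illustrative assumptions for exposition, not the actual ArcticTraining API.

from abc import ABC, abstractmethod

class BaseModelFactory(ABC):
    # Attribute defining core functionality: the config type this factory accepts
    config_type: type

    def __init__(self, config) -> None:
        self.config = config
        self._model = None

    @property
    def model(self):
        # Property tracking runtime state: the model once it has been created
        return self._model

    @abstractmethod
    def create_model(self):
        # Method enabling inter-object communication: the trainer calls this
        # to obtain a training-ready model
        ...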
By adhering to these interfaces, building blocks integrate seamlessly into the training loop, keeping code for tasks like checkpointing, modeling and data processing well-organized and independent. This modularity not only reduces code complexity but also allows users to easily swap components via YAML-defined training recipes. Developers can experiment with new configurations by simply replacing or extending specific blocks while still maintaining compatibility with the broader workflow.
The use of these building blocks eliminates much of the boilerplate code associated with writing custom trainers. Developers can focus on small, targeted code blocks to modify or extend existing methods and properties, enabling rapid prototyping and experimentation without sacrificing cohesion or correctness.
For deeper customization, ArcticTraining offers a robust callback system that wraps class methods. This system allows users to extend existing classes by modifying the inputs and outputs of interface methods via callback functions. Combined with the modular building blocks, these callbacks provide the flexibility to experiment freely while ensuring that training workflows remain robust and reliable.
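As a hedged illustration of the idea (a minimal sketch, not ArcticTraining’s actual implementation), a callback system that wraps class methods can look like the following, where callbacks adjust a method’s inputs before it runs and its outputs after:

import functools

def with_callbacks(method):
    # Wrap a method so registered callbacks can modify its inputs and outputs
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        for cb in getattr(self, "pre_callbacks", []):
            args, kwargs = cb(self, args, kwargs)  # adjust inputs
        output = method(self, *args, **kwargs)
        for cb in getattr(self, "post_callbacks", []):
            output = cb(self, output)  # adjust outputs
        return output
    return wrapper

class Trainer:
    pre_callbacks = []
    post_callbacks = [lambda self, loss: loss * 0.5]  # e.g., scale the loss

    @with_callbacks
    def loss(self, batch):
        return sum(batch)  # stand-in loss computation

print(Trainer().loss([1.0, 2.0, 3.0]))  # prints 3.0: 6.0 halved by the post-callback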
ArcticTraining delivers all these features through a user-friendly CLI, where a YAML configuration file defines the training recipe. Custom code, including extensions of existing building blocks, can be registered via a Python script referenced in the YAML file. For even greater flexibility, ArcticTraining also provides a Python API, empowering users to tailor their workflows to meet specific needs.
Basic use case 1: SFT Trainer
The ArcticTraining library provides all the essential building blocks for creating custom training pipelines, along with a ready-to-use supervised fine-tuning (SFT) trainer included in the core library. The SFT trainer serves as a showcase for how minimal changes are needed to implement a custom trainer. By defining only the loss function, users can leverage all other components from the core building blocks without additional customization.
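For illustration, such a minimal trainer might look like the sketch below. It follows the registration pattern shown later in this post, but the trainer name and the loss body (a generic causal-LM cross entropy over an assumed self.model attribute) are assumptions, not ArcticTraining’s exact code.

import torch
import torch.nn.functional as F
from arctic_training import SFTTrainer, register

@register
class MySFTTrainer(SFTTrainer):
    name = "my-sft"  # hypothetical trainer name

    def loss(self, batch) -> torch.Tensor:
        # Assumed: the trainer exposes the model as self.model
        outputs = self.model(input_ids=batch["input_ids"],
                             attention_mask=batch["attention_mask"])
        # Shift so each position predicts the next token
        logits = outputs.logits[:, :-1, :].contiguous()
        labels = batch["labels"][:, 1:].contiguous()
        return F.cross_entropy(logits.view(-1, logits.size(-1)),
                               labels.view(-1), ignore_index=-100)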
Using the SFT trainer is straightforward. Users specify the base model, input data sets and desired hyperparameters for training in a YAML configuration file. This file acts as a concise recipe for the training run, requiring no additional code. For example, to kick off training with the example configuration file below, users simply pass the file to the ArcticTraining launcher with the command: arctic_training sft-llama-8b.yaml
type: sft
micro_batch_size: 2
model:
  name_or_path: meta-llama/Llama-3.1-8B-Instruct
# DeepSpeed config is passed through to DeepSpeed init
deepspeed:
  zero_optimization:
    stage: 2
# Describe what datasets to train over
data:
  sources:
    - HuggingFaceH4/ultrachat_200k
    - meta-math/MetaMathQA
# Set optimizer hyperparameters
optimizer:
  lr: 1e-5
# Describe checkpoint block
checkpoint:
  - type: huggingface
    output_dir: <output-path>
    save_end_of_training: true
This simple example demonstrates how ArcticTraining abstracts the complexities of building a training pipeline, enabling users to focus on their tasks without writing extensive boilerplate code.
Basic use case 2: Data generation
Existing training frameworks often focus on algorithm coverage, utilizing methods such as SFT, direct preference optimization (DPO), and reinforcement learning from human feedback (RLHF). However, data generation is frequently overlooked, despite its crucial role in adapting to evolving user preferences and enabling new capabilities.
For instance, as preferences shift, data regeneration and creation become essential for maintaining optimal performance. Consider the question: "Is a cat or a dog better for pet raising?" In early 2023, the standard AI response was, "I do not know since I am a virtual AI assistant." However, by the time of this publication, the response had evolved to, "Whether a cat or a dog is a better pet depends on your lifestyle, preferences, and what you're looking for in a pet. Here are some factors to consider for each…" (as provided by GPT-4o).
To achieve the desired outcomes, it is necessary to recreate existing data and generate new data for emerging algorithms, such as Monte Carlo tree search (MCTS), DPO and new chain-of-thought (CoT) styles. To bridge the gap between training and data generation, ArcticTraining offers built-in data generation support. This includes prompt-based data generation through both cloud services (like OpenAI batch inference) and local model services (such as vLLM). Additionally, ArcticTraining uses LLMs as judges for data quality assurance and data cleaning.
Below is a simple example to illustrate how these features can be utilized effectively.
from arctic_training.synth import OpenAISynth, AzureOpenAISynth, VllmSynth, CortexSynth

# If we use Snowflake Cortex complete
client = CortexSynth(connection_params=...)

# If we use Azure OpenAI
client = AzureOpenAISynth(api_key=...)

# Other supported backends
client = OpenAISynth(api_key=...)
client = VllmSynth(model_params=...)

for name in ["Clara", "Justin", "Terry"]:
    client.add_chat_to_batch_task(
        task_name="say_hello",
        model="gpt-4o",
        messages=[
            {"role": "user", "content": f"Generate greetings to {name}"},
        ],
    )

extracted_messages = client.extract_messages_from_responses(
    client.execute_batch_task("say_hello")
)
print(extracted_messages)
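Building on the same batch-task API, LLM-as-judge cleaning can be sketched as follows; the judging prompt, the score threshold and the response parsing are illustrative assumptions, not part of the API shown above.

candidates = ["Hello Clara, wonderful to see you!", "hi"]  # generations to screen

for text in candidates:
    client.add_chat_to_batch_task(
        task_name="judge_quality",
        model="gpt-4o",
        messages=[
            {"role": "user",
             "content": f"Rate this greeting from 1 to 10. Reply with the number only.\n{text}"},
        ],
    )

judgments = client.extract_messages_from_responses(
    client.execute_batch_task("judge_quality")
)

kept = []
for candidate, judgment in zip(candidates, judgments):
    try:
        if int(str(judgment).strip()) >= 8:  # keep only high-scoring samples
            kept.append(candidate)
    except ValueError:
        pass  # drop samples whose judgment cannot be parsed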
Advanced trainer case studies
To show how ArcticTraining can handle more complex training setups, we’ll look at two examples: SwiftKV and a speculative decoder. While the SFT trainer needs only a single change to the base building blocks, these trainers involve deeper modifications, such as customizing the model construction and the training loop.
In both cases, ArcticTraining makes it easy to customize these components by letting you extend the base functionality. Using callbacks and selectively redefining class methods, you can add the specific behavior you need without breaking the rest of the framework. These examples highlight how flexible ArcticTraining can be when tackling more advanced training challenges.
SwiftKV trainer
SwiftKV, a novel approach to reducing LLM computation, is designed to significantly cut down both prefill compute and memory requirements during inference. For more details on how SwiftKV achieves these impressive results, see our previously published SwiftKV blog. In summary, SwiftKV approximates the key-value (KV) cache of the later transformer layers, resulting in a model transformation that reduces prefill computation by half and memory usage by 62.5%, all with minimal quality impact. This makes it an excellent choice for scaling inference in resource-constrained scenarios.
To train SwiftKV models using ArcticTraining, a user creates a train.py that provides an implementation of the SwiftKVTrainer class. The class defines and registers all the relevant components required for training a SwiftKV model, such as:
Model factory: Define and register the SwiftKV model factory to load the SwiftKV model. (This ensures specific weights are frozen, and the SwiftKV params are initialized correctly.)
Loss function: Define the SwiftKV loss function by overriding the SFT trainer loss method.
Configuration: Define and register the SwiftKV config that allows users to configure different SwiftKV hyperparameters.
Code example:
# train.py
import torch
from transformers import PreTrainedModel

from arctic_training import ModelConfig, SFTTrainer, HFModelFactory, register

class SwiftKVModelConfig(ModelConfig):
    num_key_value_layers: int
    key_value_group_size: int

class SwiftKVModelFactory(HFModelFactory):
    name = "swiftkv"
    config_type = SwiftKVModelConfig

    def post_create_model_callback(self, model) -> PreTrainedModel:
        ...

# Trainer and ModelFactory are registered in ArcticTraining
@register
class SwiftKVTrainer(SFTTrainer):
    name = "swiftkv"
    model_factory_type = SwiftKVModelFactory

    # Define custom loss method
    def loss(self, batch) -> torch.Tensor:
        ...
The SwiftKV-specific components of the YAML are the trainer type swiftkv and the two model-specific parameters, num_key_value_layers and key_value_group_size. Otherwise, the YAML is essentially the same as the previous SFT example YAML.
YAML example:
type: swiftkv
micro_batch_size: 2
model:
  name_or_path: meta-llama/Llama-3.1-8B-Instruct
  num_key_value_layers: 16
  key_value_group_size: 1
deepspeed:
  zero_optimization:
    stage: 2
data:
  sources:
    - HuggingFaceH4/ultrachat_200k
    - meta-math/MetaMathQA
optimizer:
  lr: 1e-5
checkpoint:
  - type: huggingface
    output_dir: <output-path>
    save_end_of_training: true
Speculative decoder trainer
Speculative decoding is a technique that speeds up LLM inference by speculating multiple output tokens with a small, fast speculator or draft model, then verifying them with the LLM. It works because verification is faster than generation.
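To make the verification step concrete, here is a hedged sketch of the greedy accept/verify loop (production systems often use rejection sampling over full distributions instead); target_model is assumed to return Hugging Face-style logits, and a batch size of 1 is assumed.

import torch

@torch.no_grad()
def verify_draft(target_model, prefix_ids, draft_tokens):
    # One forward pass of the large model scores ALL drafted tokens at once;
    # this is why verification is cheaper than token-by-token generation
    ids = torch.cat([prefix_ids, draft_tokens], dim=-1)
    logits = target_model(ids).logits
    # Logits at position i predict token i + 1, so this slice aligns with the drafts
    preds = logits[:, prefix_ids.size(-1) - 1 : -1, :].argmax(dim=-1)
    # Accept drafted tokens until the first disagreement with the target model
    accepted = (preds == draft_tokens).cumprod(dim=-1)
    return draft_tokens[:, : int(accepted.sum())]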
A good speculator or draft model needs its outputs to closely match the output distribution of the original model. A good way to achieve this is to train the draft model on synthetic data generated by the original model, using a knowledge distillation process where the original model acts as a teacher.
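A minimal sketch of such a distillation objective, assuming per-position logits are available from both the draft model and the teacher (the function name and temperature handling are illustrative), might be:

import torch.nn.functional as F

def distillation_loss(draft_logits, teacher_logits, temperature: float = 1.0):
    # KL(teacher || draft) over the vocabulary; the temperature softens both
    # distributions and the T^2 factor keeps gradient magnitudes comparable
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    draft_log_probs = F.log_softmax(draft_logits / temperature, dim=-1)
    return F.kl_div(draft_log_probs, teacher_probs, reduction="batchmean") * temperature**2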
ArcticTraining not only makes it simple to implement the knowledge-distillation pipeline with a custom loss function, but also makes it possible to generate the synthetic data on the fly during training itself.
With this release, we share an example speculative decoder trainer that implements a type of speculator called MLP-Speculator, where:
We define and register everything as in SwiftKV (model, config, loss), but also override the core training loop via the step method (see here for details).
We retain all existing callbacks around logging, metric reporting, etc., even though the training loop is overridden.
The robust callback system in ArcticTraining works even on methods that are redefined in inheriting classes.
# train.py
import torch
from transformers import PreTrainedModel

from arctic_training import ModelConfig, SFTTrainer, HFModelFactory, register

class SpecDecodeModelConfig(ModelConfig):
    speculator_path: str
    gen_prompt_length: int = 64
    gen_seq_length: int = 256
    ...

class SpecDecodeModelFactory(HFModelFactory):
    name = "spec-decode"
    config_type = SpecDecodeModelConfig

    def post_create_model_callback(self, model) -> PreTrainedModel:
        ...

# Trainer and ModelFactory are registered in ArcticTraining
@register
class SpecDecodeTrainer(SFTTrainer):
    name = "spec-decode"
    model_factory_type = SpecDecodeModelFactory

    def loss(self, batch) -> torch.Tensor:
        ...

    # Overwrite the main training loop
    def step(self, batch) -> None:
        ...
The spec-decode-specific components of the YAML are the trainer type spec-decode and the generation-related parameters (e.g., gen_prompt_length and gen_seq_length, which control the prompt length and sequence length, respectively, for synthetic data generation). Otherwise, the YAML is essentially the same as the previous SFT and SwiftKV example YAMLs.
YAML config file:
type: spec-decode
micro_batch_size: 2
model:
  name_or_path: meta-llama/Llama-3.1-8B-Instruct
  gen_prompt_length: 64
  gen_seq_length: 256
deepspeed:
  zero_optimization:
    stage: 2
data:
  sources:
    - HuggingFaceH4/ultrachat_200k
    - meta-math/MetaMathQA
optimizer:
  lr: 1e-5
checkpoint:
  - type: huggingface
    output_dir: <output-path>
    save_end_of_training: true
Benefits of ArcticTraining
The versatility of ArcticTraining shines through in both simple and advanced use cases:
Simplified training pipelines: From SFT to advanced trainers, like SwiftKV and speculative decoders, ArcticTraining minimizes boilerplate code and accelerates setup.
Unified framework: Consistent YAML configurations and modular components make ArcticTraining a one-stop solution for diverse needs, from data generation to algorithm development.
Customizability: Whether extending a loss function or rewriting the training loop, ArcticTraining’s design ensures users can implement their vision with minimal disruption.
Robust ecosystem: A callback system and built-in support for integration with cloud and local services ensure reliability across different setups.
Scalability: ArcticTraining handles everything from small-scale fine-tuning to enterprise-grade optimizations with ease.
By addressing the full lifecycle of training — from data generation to advanced trainer customization — ArcticTraining empowers researchers and developers to focus on innovation and achieve their goals faster.
Getting started
Explore the repository and documentation: Visit the ArcticTraining GitHub repository for source code, examples and quickstart guides. The official documentation provides in-depth tutorials and resources to help you kick off your training workflows.
Roadmap and upcoming features: We’re continuously enhancing ArcticTraining with exciting features, including:
Direct preference optimization (DPO) trainer: Simplify preference-based training
Preference data generation and verification: Generate and validate high-quality training data
Long-sequence training support: Efficiently train models on extended contexts
Multimodal training support: Seamlessly integrate text, image and other data modalities
Join the collaboration: We welcome feedback, contributions and collaboration! Whether you have feature requests, ideas for improvements or want to contribute code, we’d love to hear from you via GitHub and LinkedIn.
Contributors:
Snowflake AI Research: Zhewei Yao, Jeff Rasley, Michael Wyatt, Canwen Xu, Samyam Rajbhandari, Aurick Qiao, Yuxiong He