rl
A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
Stars: 3062
TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch. It provides PyTorch- and **Python-first**, low- and high-level abstractions for RL that are intended to be **efficient**, **modular**, **documented** and properly **tested**. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, so that researchers can easily swap components, transform them, or write new ones with little effort.
README:
What's New | LLM API | Getting Started | Documentation | TensorDict | Features | Examples, tutorials and demos | Citation | Installation | Asking a question | Contributing
TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
TorchRL now provides a powerful command-line interface that lets you train state-of-the-art RL agents with simple bash commands! No Python scripting required - just run training with customizable parameters:
- One-Command Training: `python sota-implementations/ppo_trainer/train.py`
- Full Customization: Override any parameter via command line: `trainer.total_frames=2000000 optimizer.lr=0.0003`
- Multi-Environment Support: Switch between Gym, Brax, DM Control, and more with `env=gym training_env.create_env_fn.base_env.env_name=HalfCheetah-v4`
- Built-in Logging: TensorBoard, Weights & Biases, CSV logging out of the box
- Hydra-Powered: Leverages Hydra's powerful configuration system for maximum flexibility
- Production Ready: Same robust training pipeline as our SOTA implementations
Perfect for: Researchers, practitioners, and anyone who wants to train RL agents without diving into implementation details.
Prerequisites: The training interface requires Hydra for configuration management. Install with:
pip install "torchrl[utils]"
# or manually:
pip install hydra-core omegaconf
Check out the complete CLI documentation to get started!
This release introduces a comprehensive revamp of TorchRL's vLLM integration, delivering significant improvements in performance, scalability, and usability for large language model inference and training workflows:
- AsyncVLLM Service: Production-ready distributed vLLM inference with multi-replica scaling and automatic Ray actor management
- Multiple Load Balancing Strategies: Routing strategies including prefix-aware, request-based, and KV-cache load balancing for optimal performance
- Unified vLLM Architecture: New `RLvLLMEngine` interface standardizing all vLLM backends, with a simplified `vLLMUpdaterV2` for seamless weight updates
- Distributed Data Loading: New `RayDataLoadingPrimer` for shared, distributed data loading across multiple environments
- Enhanced Performance: Native vLLM batching, concurrent request processing, and optimized resource allocation via Ray placement groups
# Simple AsyncVLLM usage - production ready!
from torchrl.modules.llm import AsyncVLLM, vLLMWrapper
# Create distributed vLLM service with load balancing
service = AsyncVLLM.from_pretrained(
"Qwen/Qwen2.5-7B",
num_devices=2, # Tensor parallel across 2 GPUs
num_replicas=4, # 4 replicas for high throughput
max_model_len=4096
)
# Use with TorchRL's LLM wrappers
wrapper = vLLMWrapper(service, input_mode="history")
# Simplified weight updates
from torchrl.collectors.llm import vLLMUpdaterV2
updater = vLLMUpdaterV2(service)  # Auto-configures from the engine
This revamp positions TorchRL as the leading platform for scalable LLM inference and training, providing production-ready tools for both research and deployment scenarios.
TorchRL now includes an experimental PPOTrainer that provides a complete, configurable PPO training solution! This prototype feature combines TorchRL's modular components into a cohesive training system with sensible defaults:
- Complete Training Pipeline: Handles environment setup, data collection, loss computation, and optimization automatically
- Extensive Configuration: Comprehensive Hydra-based config system for easy experimentation and hyperparameter tuning
- Built-in Logging: Automatic tracking of rewards, actions, episode completion rates, and training statistics
- Modular Design: Built on existing TorchRL components (collectors, losses, replay buffers) for maximum flexibility
- Minimal Code: Complete SOTA implementation in just ~20 lines!
Working Example: See sota-implementations/ppo_trainer/ for a complete, working PPO implementation that trains on Pendulum-v1 with full Hydra configuration support.
Prerequisites: Requires Hydra for configuration management: pip install "torchrl[utils]"
Complete Training Script (sota-implementations/ppo_trainer/train.py)
import hydra
from torchrl.trainers.algorithms.configs import *
@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.train()
if __name__ == "__main__":
    main()
Complete PPO training in ~20 lines with full configurability.
API Usage Examples
# Basic usage - train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py
# Custom configuration with command-line overrides
python sota-implementations/ppo_trainer/train.py \
trainer.total_frames=2000000 \
training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
networks.policy_network.num_cells=[256,256] \
optimizer.lr=0.0003
# Use different environment and logger
python sota-implementations/ppo_trainer/train.py \
env=gym \
training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
logger=tensorboard
# See all available options
python sota-implementations/ppo_trainer/train.py --help
Future Plans: Additional algorithm trainers (SAC, TD3, DQN) and full integration of all TorchRL components within the configuration system are planned for upcoming releases.
TorchRL includes a comprehensive LLM API for post-training and fine-tuning of language models! This framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
- Unified LLM Wrappers: Seamless integration with Hugging Face models and vLLM inference engines
- Conversation Management: Advanced `History` class for multi-turn dialogue with automatic chat template detection
- Tool Integration: Built-in support for Python code execution, function calling, and custom tool transforms
- Specialized Objectives: GRPO (Group Relative Policy Optimization) and SFT loss functions optimized for language models
- High-Performance Collectors: Async data collection with distributed training support
- Flexible Environments: Transform-based architecture for reward computation, data loading, and conversation augmentation
The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the complete documentation and GRPO implementation example to get started!
Quick LLM API Example
from torchrl.envs.llm import ChatEnv
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm import GRPOLoss
from torchrl.collectors.llm import LLMCollector
from torchrl.envs.llm.transforms import PythonInterpreter
# Create environment with Python tool execution
env = ChatEnv(
tokenizer=tokenizer,
system_prompt="You are an assistant that can execute Python code.",
batch_size=[1]
).append_transform(PythonInterpreter())
# Wrap your language model
llm = TransformersWrapper(
model=model,
tokenizer=tokenizer,
input_mode="history"
)
# Set up GRPO training
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
collector = LLMCollector(env, llm, frames_per_batch=100)
# Training loop
for data in collector:
    loss = loss_fn(data)
    loss.backward()
    optimizer.step()
- Python-first: Designed with Python as the primary language for ease of use and flexibility
- Efficient: Optimized for performance to support demanding RL research applications
- Modular, customizable, extensible: Highly modular architecture allows for easy swapping, transformation, or creation of new components
- Documented: Thorough documentation ensures that users can quickly understand and utilize the library
- Tested: Rigorously tested to ensure reliability and stability
- Reusable functionals: Provides a set of highly reusable functions for cost functions, returns, and data processing
- Aligns with PyTorch ecosystem: Follows the structure and conventions of popular PyTorch libraries (e.g., dataset pillar, transforms, models, data utilities)
- Minimal dependencies: Only requires Python standard library, NumPy, and PyTorch; optional dependencies for common environment libraries (e.g., OpenAI Gym) and datasets (D4RL, OpenX...)
Read the full paper for a more curated description of the library.
Check our Getting Started tutorials to quickly ramp up on the basic features of the library!
The TorchRL documentation can be found here. It contains tutorials and the API reference.
TorchRL also provides a RL knowledge base to help you debug your code, or simply learn the basics of RL. Check it out here.
We have some introductory videos for you to get to know the library better, check them out:
TorchRL being domain-agnostic, you can use it across many different fields. Here are a few examples:
- ACEGEN: Reinforcement Learning of Generative Chemical Agents for Drug Discovery
- BenchMARL: Benchmarking Multi-Agent Reinforcement Learning
- BricksRL: A Platform for Democratizing Robotics and Reinforcement Learning Research and Education with LEGO
- OmniDrones: An Efficient and Flexible Platform for Reinforcement Learning in Drone Control
- RL4CO: an Extensive Reinforcement Learning for Combinatorial Optimization Benchmark
- Robohive: A unified framework for robot learning
RL algorithms are very heterogeneous, and it can be hard to recycle a codebase
across settings (e.g. from online to offline, from state-based to pixel-based
learning).
TorchRL solves this problem through TensorDict,
a convenient data structure(1) that can be used to streamline one's
RL codebase.
With this tool, one can write a complete PPO training script in less than 100
lines of code!
Code
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
env = GymEnv("Pendulum-v1")
model = TensorDictModule(
nn.Sequential(
nn.Linear(3, 128), nn.Tanh(),
nn.Linear(128, 128), nn.Tanh(),
nn.Linear(128, 128), nn.Tanh(),
nn.Linear(128, 2),
NormalParamExtractor()
),
in_keys=["observation"],
out_keys=["loc", "scale"]
)
critic = ValueOperator(
nn.Sequential(
nn.Linear(3, 128), nn.Tanh(),
nn.Linear(128, 128), nn.Tanh(),
nn.Linear(128, 128), nn.Tanh(),
nn.Linear(128, 1),
),
in_keys=["observation"],
)
actor = ProbabilisticActor(
model,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={"low": -1.0, "high": 1.0},
return_log_prob=True
)
buffer = TensorDictReplayBuffer(
storage=LazyTensorStorage(1000),
sampler=SamplerWithoutReplacement(),
batch_size=50,
)
collector = SyncDataCollector(
env,
actor,
frames_per_batch=1000,
total_frames=1_000_000,
)
loss_fn = ClipPPOLoss(actor, critic)
adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data)
        for sample in buffer:  # consume data
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items() if
                key.startswith("loss")
            )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
Here is an example of how the environment API
relies on tensordict to carry data from one function to another during a rollout
execution:

TensorDict makes it easy to re-use pieces of code across environments, models and
algorithms.
Code
For instance, here's how to code a rollout in TorchRL:
- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
model,
in_keys=["observation_pixels", "observation_vector"],
out_keys=["action"],
)
out = []
for i in range(n_steps):
-     action, log_prob = policy(obs)
-     next_obs, reward, done, info = env.step(action)
-     out.append((obs, next_obs, action, log_prob, reward, done))
-     obs = next_obs
+     tensordict = policy(tensordict)
+     tensordict = env.step(tensordict)
+     out.append(tensordict)
+     tensordict = step_mdp(tensordict)  # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0)  # TensorDict supports multiple tensor operations
Using this, TorchRL abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing all primitives to be easily recycled across settings.
Code
Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
-     replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
+     replay_buffer.add(tensordict)
      for j in range(num_optim_steps):
-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+         tensordict = replay_buffer.sample(batch_size)
+         loss = loss_fn(tensordict)
          loss.backward()
          optim.step()
          optim.zero_grad()
This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
TensorDict supports multiple tensor operations on its device and shape (the shape of a TensorDict, i.e. its batch size, is the leading N dimensions common to all of its contained tensors):
Code
# stack and cat
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location
tensordict.cuda()
tensordict.to("cuda:1")
tensordict.share_memory_()
TensorDict comes with a dedicated tensordict.nn
module that contains everything you might need to write your model with it.
And it is functorch and torch.compile compatible!
Code
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
The TensorDictSequential class allows you to branch sequences of nn.Module instances in a highly modular way.
For instance, here is an implementation of a transformer using the encoder and decoder blocks:
encoder_module = TransformerEncoder(...)
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]TensorDictSequential allows to isolate subgraphs by querying a set of desired input / output keys:
transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
Check TensorDict tutorials to learn more!
- A common interface for environments which supports common libraries (OpenAI gym, deepmind control lab, etc.)(1) and state-less execution (e.g. Model-based environments). The batched environment containers allow parallel execution(2). A common PyTorch-first tensor-specification class is also provided. TorchRL's environments API is simple but stringent and specific. Check the documentation and tutorial to learn more!
Code
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
assert tensordict.shape == [4, 20]  # 4 envs, 20 steps rollout
env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
- Multiprocess and distributed data collectors(2) that work synchronously or asynchronously. Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
Code
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
collector = MultiaSyncDataCollector(
    [env_make, env_make],
    policy=policy,
    devices=["cuda:0", "cuda:0"],
    total_frames=10000,
    frames_per_batch=50,
    ...
)
for i, tensordict_data in enumerate(collector):
    loss = loss_module(tensordict_data)
    loss.backward()
    optim.step()
    optim.zero_grad()
    collector.update_policy_weights_()
Check our distributed collector examples to learn more about ultra-fast data collection with TorchRL.
- Efficient(2) and generic(1) replay buffers with modularized storage:
Code
storage = LazyMemmapStorage(  # memory-mapped (physical) storage
    cfg.buffer_size,
    scratch_dir="/tmp/",
)
buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    collate_fn=lambda x: x,
    pin_memory=device != torch.device("cpu"),
    prefetch=10,  # multi-threaded sampling
    storage=storage,
)
Replay buffers are also offered as wrappers around common datasets for offline RL:
Code
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
data = D4RLExperienceReplay(
    "maze2d-open-v0",
    split_trajs=True,
    batch_size=128,
    sampler=SamplerWithoutReplacement(drop_last=True),
)
for sample in data:  # or alternatively sample = data.sample()
    fun(sample)
- Cross-library environment transforms(1), executed on device and in a vectorized fashion(2), which process and prepare the data coming out of the environments to be used by the agent:
Code
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
env = TransformedEnv(
    env_base,
    Compose(
        ToTensorImage(),
        ObservationNorm(loc=0.5, scale=1.0),
    ),  # executes the transforms once and on device
)
tensordict = env.reset()
assert tensordict.device == torch.device("cuda:0")
Other transforms include: reward scaling (RewardScaling), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of successive operations (CatFrames), resizing (Resize) and many more.
Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it easy to add and remove them at will:
env.insert_transform(0, NoopResetEnv()) # inserts the NoopResetEnv transform at the index 0
Nevertheless, transforms can access and execute operations on the parent environment:
transform = env.transform[1]  # gathers the second transform of the list
parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
- Various tools for distributed learning (e.g. memory mapped tensors)(2);
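For instance, memory-mapped tensors from the tensordict package let a buffer live on disk or in shared memory while exposing a regular tensor interface. The snippet below is a minimal sketch; the exact MemoryMappedTensor constructor arguments may vary across tensordict versions.
Code
import torch
from tensordict import MemoryMappedTensor
# Allocate a tensor backed by a memory-mapped file; it can be shared
# across processes without copying the underlying storage (illustrative sketch).
mmap_tensor = MemoryMappedTensor.empty((1000, 3), dtype=torch.float32)
mmap_tensor[0] = torch.tensor([1.0, 2.0, 3.0])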
- Various architectures and models (e.g. actor-critic)(1):
Code
# create an nn.Module
common_module = ConvNet(
    bias_last_layer=True,
    depth=None,
    num_cells=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
)
# Wrap it in a SafeModule, indicating what key to read in and where to
# write out the output
common_module = SafeModule(
    common_module,
    in_keys=["pixels"],
    out_keys=["hidden"],
)
# Wrap the policy module in NormalParamsWrapper, such that the output
# tensor is split in loc and scale, and scale is mapped onto a positive space
policy_module = SafeModule(
    NormalParamsWrapper(
        MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
    ),
    in_keys=["hidden"],
    out_keys=["loc", "scale"],
)
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
# SafeProbabilisticModule, indicating how to build the
# torch.distribution.Distribution object and what to do with it
policy_module = SafeProbabilisticTensorDictSequential(  # stochastic policy
    policy_module,
    SafeProbabilisticModule(
        in_keys=["loc", "scale"],
        out_keys="action",
        distribution_class=TanhNormal,
    ),
)
value_module = MLP(
    num_cells=[64, 64],
    out_features=1,
    activation=nn.ELU,
)
# Wrap the policy and value function in a common module
actor_value = ActorValueOperator(common_module, policy_module, value_module)
# standalone policy from this
standalone_policy = actor_value.get_policy_operator()
- Exploration wrappers and modules to easily swap between exploration and exploitation(1):
Code
policy_explore = EGreedyWrapper(policy)
with set_exploration_type(ExplorationType.RANDOM):
    tensordict = policy_explore(tensordict)  # will use eps-greedy
with set_exploration_type(ExplorationType.DETERMINISTIC):
    tensordict = policy_explore(tensordict)  # will not use eps-greedy
- A series of efficient loss modules and highly vectorized functional return and advantage computation.
Code
from torchrl.objectives import DQNLoss
loss_module = DQNLoss(value_network=value_network, gamma=0.99)
tensordict = replay_buffer.sample(batch_size)
loss = loss_module(tensordict)
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
- A generic trainer class(1) that executes the aforementioned training loop. Through a hooking mechanism, it also supports any logging or data transformation operation at any given time.
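As a rough sketch of how such a trainer can be assembled from the components above (the argument and hook names used here, such as Trainer and LogReward, are indicative and may differ between TorchRL releases):
Code
from torchrl.trainers import Trainer, LogReward
# Assemble a trainer from previously created components
# (collector, loss_module and optimizer as in the PPO example above).
trainer = Trainer(
    collector=collector,
    total_frames=1_000_000,
    loss_module=loss_module,
    optimizer=optimizer,
    optim_steps_per_batch=10,
)
# Hooks can be registered to log metrics or transform data at chosen points of the loop.
LogReward(log_pbar=True).register(trainer)
trainer.train()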
- Various recipes to build models that correspond to the environment being deployed.
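A hypothetical recipe of this kind (not a built-in helper) could read the environment specs and size the networks accordingly:
Code
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP
# Hypothetical recipe: infer network sizes from the environment specs.
env = GymEnv("Pendulum-v1")
obs_dim = env.observation_spec["observation"].shape[-1]
act_dim = env.action_spec.shape[-1]
policy_net = MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[64, 64])  # loc and scale heads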
- LLM API: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends, conversation management with automatic chat template detection, tool integration (Python execution, function calling), specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning, and tool-augmented training scenarios.
Code
from torchrl.envs.llm import ChatEnv
from torchrl.modules.llm import TransformersWrapper
from torchrl.envs.llm.transforms import PythonInterpreter
# Create environment with tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You can execute Python code.",
    batch_size=[1],
).append_transform(PythonInterpreter())
# Wrap language model for training
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history",
)
# Multi-turn conversation with tool use
obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
llm_output = llm(obs)  # Generates response
obs = env.step(llm_output)  # Environment processes response
If you feel a feature is missing from the library, please submit an issue! If you would like to contribute to new features, check our call for contributions and our contribution page.
A series of State-of-the-Art implementations are provided for illustrative purposes:
| Algorithm | Compile Support** | Tensordict-free API | Modular Losses | Continuous and Discrete |
| --- | --- | --- | --- | --- |
| DQN | 1.9x | + | NA | + (through ActionDiscretizer transform) |
| DDPG | 1.87x | + | + | - (continuous only) |
| IQL | 3.22x | + | + | + |
| CQL | 2.68x | + | + | + |
| TD3 | 2.27x | + | + | - (continuous only) |
| TD3+BC | untested | + | + | - (continuous only) |
| A2C | 2.67x | + | - | + |
| PPO | 2.42x | + | - | + |
| SAC | 2.62x | + | - | + |
| REDQ | 2.28x | + | - | - (continuous only) |
| Dreamer v1 | untested | + | + (different classes) | - (continuous only) |
| Decision Transformers | untested | + | NA | - (continuous only) |
| CrossQ | untested | + | + | - (continuous only) |
| Gail | untested | + | NA | + |
| Impala | untested | + | - | + |
| IQL (MARL) | untested | + | + | + |
| DDPG (MARL) | untested | + | + | - (continuous only) |
| PPO (MARL) | untested | + | - | + |
| QMIX-VDN (MARL) | untested | + | NA | + |
| SAC (MARL) | untested | + | - | + |
| RLHF | NA | + | NA | NA |
| LLM API (GRPO) | NA | + | + | NA |
** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on architecture and device.
and many more to come!
Code examples displaying toy code snippets and training scripts are also available:
- LLM API & GRPO - Complete language model fine-tuning pipeline
- RLHF
- Memory-mapped replay buffers
Check the examples directory for more details about handling the various configuration settings.
We also provide tutorials and demos that give a sense of what the library can do.
If you're using TorchRL, please refer to this BibTeX entry to cite this work:
@misc{bou2023torchrl,
title={TorchRL: A data-driven decision-making library for PyTorch},
author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
year={2023},
eprint={2306.00577},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
python -m venv torchrl
source torchrl/bin/activate  # On Windows use: venv\Scripts\activate
Or create a conda environment where the packages will be installed.
conda create --name torchrl python=3.10
conda activate torchrl
Depending on the use of torchrl that you want to make, you may want to
install the latest (nightly) PyTorch release or the latest stable version of PyTorch.
See here for a detailed list of commands,
including pip3 or other special installation instructions.
TorchRL offers a few pre-defined dependencies such as "torchrl[tests]", "torchrl[atari]", "torchrl[utils]" etc.
For the experimental training interface and configuration system, install:
pip3 install "torchrl[utils]" # Includes hydra-core and other utilitiesYou can install the latest stable release by using
pip3 install torchrlThis should work on linux (including AArch64 machines), Windows 10 and OsX (Metal chips only). On certain Windows machines (Windows 11), one should build the library locally. This can be done in two ways:
# Install and build locally v0.8.1 of the library without cloning
pip3 install git+https://github.com/pytorch/rl@v0.8.1
# Clone the library and build it locally
git clone https://github.com/pytorch/tensordict
git clone https://github.com/pytorch/rl
pip install -e tensordict
pip install -e rl
Note that the tensordict local build requires cmake to be installed via homebrew (MacOS) or another package manager
such as apt, apt-get, conda or yum but NOT pip, as well as pip install "pybind11[global]".
One can also build the wheels to distribute to co-workers using
pip install build
python -m build --wheel
Your wheels will be stored in ./dist/torchrl<name>.whl and installable via
pip install torchrl<name>.whl
The nightly build can be installed via
pip3 install tensordict-nightly torchrl-nightly
which we currently only ship for Linux machines. Importantly, the nightly builds require the nightly builds of PyTorch too. Also, a local build of torchrl with the nightly build of tensordict may fail - install both nightlies or both local builds but do not mix them.
Disclaimer: As of today, TorchRL requires Python 3.10+ and is roughly compatible with any pytorch version >= 2.1. Installing it will not directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest PyTorch to be installed and we are working hard to loosen that requirement. The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above. Some features (e.g., working with nested jagged tensors) may also be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version unless there is a strong reason not to do so.
Optional dependencies
The following libraries can be installed depending on the usage one wants to make of torchrl:
# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher
# rendering
pip3 install "moviepy<2.0.0"
# deepmind control suite
pip3 install dm_control
# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame
# tests
pip3 install pytest pyyaml pytest-instafail
# tensorboard
pip3 install tensorboard
# wandb
pip3 install wandb
Versioning issues can cause error messages of the type undefined symbol
and such. For these, refer to the versioning issues document
for a complete explanation and proposed workarounds.
If you spot a bug in the library, please raise an issue in this repo.
If you have a more generic question regarding RL in PyTorch, post it on the PyTorch forum.
Contributions to torchrl are welcome! Feel free to fork, submit issues and PRs. You can check out the detailed contribution guide here. As mentioned above, a list of open contributions can be found here.
Contributors are recommended to install pre-commit hooks (using pre-commit install). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending -n to your commit command: git commit -m <commit message> -n
This library is released as a PyTorch beta feature. BC-breaking changes are likely to happen, but they will be introduced with a deprecation warning after a few release cycles.
TorchRL is licensed under the MIT License. See LICENSE for details.
Similar Open Source Tools
sdialog
SDialog is an MIT-licensed open-source toolkit for building, simulating, and evaluating LLM-based conversational agents end-to-end. It aims to bridge agent construction, user simulation, dialog generation, and evaluation in a single reproducible workflow, enabling the generation of reliable, controllable dialog systems or data at scale. The toolkit standardizes a Dialog schema, offers persona-driven multi-agent simulation with LLMs, provides composable orchestration for precise control over behavior and flow, includes built-in evaluation metrics, and offers mechanistic interpretability. It allows for easy creation of user-defined components and interoperability across various AI platforms.
lionagi
LionAGI is a robust framework for orchestrating multi-step AI operations with precise control. It allows users to bring together multiple models, advanced reasoning, tool integrations, and custom validations in a single coherent pipeline. The framework is structured, expandable, controlled, and transparent, offering features like real-time logging, message introspection, and tool usage tracking. LionAGI supports advanced multi-step reasoning with ReAct, integrates with Anthropic's Model Context Protocol, and provides observability and debugging tools. Users can seamlessly orchestrate multiple models, integrate with Claude Code CLI SDK, and leverage a fan-out fan-in pattern for orchestration. The framework also offers optional dependencies for additional functionalities like reader tools, local inference support, rich output formatting, database support, and graph visualization.
lionagi
LionAGI is a powerful intelligent workflow automation framework that introduces advanced ML models into any existing workflows and data infrastructure. It can interact with almost any model, run interactions in parallel for most models, produce structured pydantic outputs with flexible usage, automate workflow via graph based agents, use advanced prompting techniques, and more. LionAGI aims to provide a centralized agent-managed framework for "ML-powered tools coordination" and to dramatically lower the barrier of entries for creating use-case/domain specific tools. It is designed to be asynchronous only and requires Python 3.10 or higher.
curator
Bespoke Curator is an open-source tool for data curation and structured data extraction. It provides a Python library for generating synthetic data at scale, with features like programmability, performance optimization, caching, and integration with HuggingFace Datasets. The tool includes a Curator Viewer for dataset visualization and offers a rich set of functionalities for creating and refining data generation strategies.
qa-mdt
This repository provides an implementation of QA-MDT, integrating state-of-the-art models for music generation. It offers a Quality-Aware Masked Diffusion Transformer for enhanced music generation. The code is based on various repositories like AudioLDM, PixArt-alpha, MDT, AudioMAE, and Open-Sora. The implementation allows for training and fine-tuning the model with different strategies and datasets. The repository also includes instructions for preparing datasets in LMDB format and provides a script for creating a toy LMDB dataset. The model can be used for music generation tasks, with a focus on quality injection to enhance the musicality of generated music.
effective_llm_alignment
This is a super customizable, concise, user-friendly, and efficient toolkit for training and aligning LLMs. It provides support for various methods such as SFT, Distillation, DPO, ORPO, CPO, SimPO, SMPO, Non-pair Reward Modeling, Special prompts basket format, Rejection Sampling, Scoring using RM, Effective FAISS Map-Reduce Deduplication, LLM scoring using RM, NER, CLIP, Classification, and STS. The toolkit offers key libraries like PyTorch, Transformers, TRL, Accelerate, FSDP, DeepSpeed, and tools for result logging with wandb or clearml. It allows mixing datasets, generation and logging in wandb/clearml, vLLM batched generation, and aligns models using the SMPO method.
semantic-kernel
Semantic Kernel is an SDK that integrates Large Language Models (LLMs) like OpenAI, Azure OpenAI, and Hugging Face with conventional programming languages like C#, Python, and Java. Semantic Kernel achieves this by allowing you to define plugins that can be chained together in just a few lines of code. What makes Semantic Kernel _special_ , however, is its ability to _automatically_ orchestrate plugins with AI. With Semantic Kernel planners, you can ask an LLM to generate a plan that achieves a user's unique goal. Afterwards, Semantic Kernel will execute the plan for the user.
GraphRAG-SDK
Build fast and accurate GenAI applications with GraphRAG SDK, a specialized toolkit for building Graph Retrieval-Augmented Generation (GraphRAG) systems. It integrates knowledge graphs, ontology management, and state-of-the-art LLMs to deliver accurate, efficient, and customizable RAG workflows. The SDK simplifies the development process by automating ontology creation, knowledge graph agent creation, and query handling, enabling users to interact and query their knowledge graphs effectively. It supports multi-agent systems and orchestrates agents specialized in different domains. The SDK is optimized for FalkorDB, ensuring high performance and scalability for large-scale applications. By leveraging knowledge graphs, it enables semantic relationships and ontology-driven queries that go beyond standard vector similarity, enhancing retrieval-augmented generation capabilities.
zo2
ZO2 (Zeroth-Order Offloading) is an innovative framework designed to enhance the fine-tuning of large language models (LLMs) using zeroth-order (ZO) optimization techniques and advanced offloading technologies. It is tailored for setups with limited GPU memory, enabling the fine-tuning of models with over 175 billion parameters on single GPUs with as little as 18GB of memory. ZO2 optimizes CPU offloading, incorporates dynamic scheduling, and has the capability to handle very large models efficiently without extra time costs or accuracy losses.
sre
SmythOS is an operating system designed for building, deploying, and managing intelligent AI agents at scale. It provides a unified SDK and resource abstraction layer for various AI services, making it easy to scale and flexible. With an agent-first design, developer-friendly SDK, modular architecture, and enterprise security features, SmythOS offers a robust foundation for AI workloads. The system is built with a philosophy inspired by traditional operating system kernels, ensuring autonomy, control, and security for AI agents. SmythOS aims to make shipping production-ready AI agents accessible and open for everyone in the coming Internet of Agents era.
dLLM-RL
dLLM-RL is a revolutionary reinforcement learning framework designed for Diffusion Large Language Models. It supports various models with diverse structures, offers inference acceleration, RL training capabilities, and SFT functionalities. The tool introduces TraceRL for trajectory-aware RL and diffusion-based value models for optimization stability. Users can download and try models like TraDo-4B-Instruct and TraDo-8B-Instruct. The tool also provides support for multi-node setups and easy building of reinforcement learning methods. Additionally, it offers supervised fine-tuning strategies for different models and tasks.
Scrapling
Scrapling is a high-performance, intelligent web scraping library for Python that automatically adapts to website changes while significantly outperforming popular alternatives. For both beginners and experts, Scrapling provides powerful features while maintaining simplicity. It offers features like fast and stealthy HTTP requests, adaptive scraping with smart element tracking and flexible selection, high performance with lightning-fast speed and memory efficiency, and developer-friendly navigation API and rich text processing. It also includes advanced parsing features like smart navigation, content-based selection, handling structural changes, and finding similar elements. Scrapling is designed to handle anti-bot protections and website changes effectively, making it a versatile tool for web scraping tasks.
xFasterTransformer
xFasterTransformer is an optimized solution for Large Language Models (LLMs) on the X86 platform, providing high performance and scalability for inference on mainstream LLM models. It offers C++ and Python APIs for easy integration, along with example codes and benchmark scripts. Users can prepare models in a different format, convert them, and use the APIs for tasks like encoding input prompts, generating token ids, and serving inference requests. The tool supports various data types and models, and can run in single or multi-rank modes using MPI. A web demo based on Gradio is available for popular LLM models like ChatGLM and Llama2. Benchmark scripts help evaluate model inference performance quickly, and MLServer enables serving with REST and gRPC interfaces.
bee-agent-framework
The Bee Agent Framework is an open-source tool for building, deploying, and serving powerful agentic workflows at scale. It provides AI agents, tools for creating workflows in Javascript/Python, a code interpreter, memory optimization strategies, serialization for pausing/resuming workflows, traceability features, production-level control, and upcoming features like model-agnostic support and a chat UI. The framework offers various modules for agents, llms, memory, tools, caching, errors, adapters, logging, serialization, and more, with a roadmap including MLFlow integration, JSON support, structured outputs, chat client, base agent improvements, guardrails, and evaluation.
pgvecto.rs
pgvecto.rs is a Postgres extension written in Rust that provides vector similarity search functions. It offers ultra-low-latency, high-precision vector search capabilities, including sparse vector search and full-text search. With complete SQL support, async indexing, and easy data management, it simplifies data handling. The extension supports various data types like FP16/INT8, binary vectors, and Matryoshka embeddings. It ensures system performance with production-ready features, high availability, and resource efficiency. Security and permissions are managed through easy access control. The tool allows users to create tables with vector columns, insert vector data, and calculate distances between vectors using different operators. It also supports half-precision floating-point numbers for better performance and memory usage optimization.
For similar tasks
Thinkless
Thinkless is a learnable framework that empowers a Language and Reasoning Model (LLM) to adaptively select between short-form and long-form reasoning based on task complexity and model's ability. It is trained under a reinforcement learning paradigm and employs control tokens for concise responses and detailed reasoning. The core method uses a Decoupled Group Relative Policy Optimization (DeGRPO) algorithm to govern reasoning mode selection and improve answer accuracy, reducing long-chain thinking by 50% - 90% on benchmarks like Minerva Algebra and MATH-500. Thinkless enhances computational efficiency of Reasoning Language Models.
For similar jobs
weave
Weave is a toolkit for developing Generative AI applications, built by Weights & Biases. With Weave, you can log and debug language model inputs, outputs, and traces; build rigorous, apples-to-apples evaluations for language model use cases; and organize all the information generated across the LLM workflow, from experimentation to evaluations to production. Weave aims to bring rigor, best-practices, and composability to the inherently experimental process of developing Generative AI software, without introducing cognitive overhead.
agentcloud
AgentCloud is an open-source platform that enables companies to build and deploy private LLM chat apps, empowering teams to securely interact with their data. It comprises three main components: Agent Backend, Webapp, and Vector Proxy. To run this project locally, clone the repository, install Docker, and start the services. The project is licensed under the GNU Affero General Public License, version 3 only. Contributions and feedback are welcome from the community.
oss-fuzz-gen
This framework generates fuzz targets for real-world `C`/`C++` projects with various Large Language Models (LLM) and benchmarks them via the `OSS-Fuzz` platform. It manages to successfully leverage LLMs to generate valid fuzz targets (which generate non-zero coverage increase) for 160 C/C++ projects. The maximum line coverage increase is 29% from the existing human-written targets.
LLMStack
LLMStack is a no-code platform for building generative AI agents, workflows, and chatbots. It allows users to connect their own data, internal tools, and GPT-powered models without any coding experience. LLMStack can be deployed to the cloud or on-premise and can be accessed via HTTP API or triggered from Slack or Discord.
VisionCraft
The VisionCraft API is a free API for using over 100 different AI models. From images to sound.
kaito
Kaito is an operator that automates the AI/ML inference model deployment in a Kubernetes cluster. It manages large model files using container images, avoids tuning deployment parameters to fit GPU hardware by providing preset configurations, auto-provisions GPU nodes based on model requirements, and hosts large model images in the public Microsoft Container Registry (MCR) if the license allows. Using Kaito, the workflow of onboarding large AI inference models in Kubernetes is largely simplified.
PyRIT
PyRIT is an open access automation framework designed to empower security professionals and ML engineers to red team foundation models and their applications. It automates AI Red Teaming tasks to allow operators to focus on more complicated and time-consuming tasks and can also identify security harms such as misuse (e.g., malware generation, jailbreaking), and privacy harms (e.g., identity theft). The goal is to allow researchers to have a baseline of how well their model and entire inference pipeline is doing against different harm categories and to be able to compare that baseline to future iterations of their model. This allows them to have empirical data on how well their model is doing today, and detect any degradation of performance based on future improvements.
Azure-Analytics-and-AI-Engagement
The Azure-Analytics-and-AI-Engagement repository provides packaged Industry Scenario DREAM Demos with ARM templates (Containing a demo web application, Power BI reports, Synapse resources, AML Notebooks etc.) that can be deployed in a customerโs subscription using the CAPE tool within a matter of few hours. Partners can also deploy DREAM Demos in their own subscriptions using DPoC.
