Python library for solving reinforcement learning (RL) problems using generative models.

Mossforest/GenerativeRL_Preview

 
 

Generative Reinforcement Learning (GRL)

English | 简体中文(Simplified Chinese)

GenerativeRL, short for Generative Reinforcement Learning, is a Python library for solving reinforcement learning (RL) problems using generative models, such as diffusion models and flow models. This library aims to provide a framework for combining the power of generative models with the decision-making capabilities of reinforcement learning algorithms.

GenerativeRL_Preview is a preview version of GenerativeRL. It is under rapid development and includes many experimental features. For the stable version, please visit GenerativeRL.

Features

  • Support for training, evaluating, and deploying diverse generative models, including diffusion models and flow models
  • Integration of generative models for state representation, action representation, policy learning, and dynamics model learning in RL
  • Implementation of popular RL algorithms tailored for generative models, such as Q-guided policy optimization (QGPO)
  • Support for various RL environments and benchmarks
  • Easy-to-use API for training and evaluation

Framework Structure

[Figure: GenerativeRL framework structure]

Integrated Generative Models

Model                                        | Score Matching | Flow Matching
Diffusion Model (Colab notebook available)   |                |
Linear VP SDE                                |                |
Generalized VP SDE                           |                |
Linear SDE                                   |                |
Flow Model (Colab notebook available)        |                |
Independent Conditional Flow Matching        | 🚫             |
Optimal Transport Conditional Flow Matching  | 🚫             |

(🚫 = not supported)
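
For orientation, the two model families listed above can be sketched in a few lines of NumPy. This is a generic textbook illustration and not the GenerativeRL implementation: the function names, the `beta_min` and `beta_max` defaults, and the noise-free straight-line path are assumptions chosen for the example.

```python
import numpy as np

def linear_vp_sde_marginal(x0, t, rng, beta_min=0.1, beta_max=20.0):
    """Sample x_t ~ p(x_t | x0) for a VP SDE with linear beta(t), t in [0, 1].

    Hypothetical helper for illustration; defaults are common choices, not
    values taken from GenerativeRL.
    """
    # Integral of beta(s) over [0, t] for beta(s) = beta_min + s * (beta_max - beta_min).
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    alpha = np.exp(-0.5 * int_beta)   # scale applied to the clean sample
    sigma = np.sqrt(1.0 - alpha**2)   # noise std, so alpha^2 + sigma^2 = 1
    eps = rng.standard_normal(x0.shape)
    return alpha * x0 + sigma * eps, eps

def icfm_pair(x0, x1, t):
    """Independent conditional flow matching with a straight, noise-free path.

    x0 and x1 are sampled independently (source and data); the velocity
    model would be regressed onto `velocity_target` at the point `x_t`.
    """
    x_t = (1.0 - t) * x0 + t * x1
    velocity_target = x1 - x0
    return x_t, velocity_target
```

A score-matching objective would regress a network onto the returned `eps` (or the corresponding score), while a flow-matching objective would regress a velocity network onto `velocity_target`.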

Integrated Algorithms

Algo./Models                         | Diffusion Model | Flow Model
QGPO                                 |                 | 🚫
SRPO                                 |                 | 🚫
GMPO (Colab notebook available)      |                 |
GMPG (Colab notebook available)      |                 |

(🚫 = not supported)

Installation

Please install from source:

git clone https://github.com/zjowowen/GenerativeRL_Preview.git
cd GenerativeRL_Preview
pip install -e .

Or you can use the docker image:

docker pull zjowowen/grl:torch2.3.0-cuda12.1-cudnn8-runtime
docker run -it --rm --gpus all zjowowen/grl:torch2.3.0-cuda12.1-cudnn8-runtime /bin/bash
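
After either installation route, a quick check that the packages are importable can save a failed run later. The sketch below uses only the standard library; the package names `grl` and `grl_pipelines` are taken from the imports in the quick-start script, and the helper name `is_installed` is made up for this example.

```python
import importlib.util

def is_installed(package_name):
    """Return True if the import system can locate the top-level package."""
    return importlib.util.find_spec(package_name) is not None

if __name__ == "__main__":
    # Package names assumed from the quick-start imports in this README.
    for name in ("grl", "grl_pipelines"):
        status = "found" if is_installed(name) else "missing"
        print(f"{name}: {status}")
```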

Quick Start

Here is an example of how to train a diffusion model for Q-guided policy optimization (QGPO) in the LunarLanderContinuous-v2 environment using GenerativeRL.

Install the required dependencies:

pip install 'gym[box2d]==0.23.1'

Download the dataset from here and save it as data.npz in the current directory.

GenerativeRL uses WandB for logging. It will ask you to log in to your account when you use it. You can disable it by running:

wandb offline

Then run the following training and deployment script:

import gym

from grl.algorithms.qgpo import QGPOAlgorithm
from grl.datasets import QGPOCustomizedDataset
from grl.utils.log import log
from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

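def qgpo_pipeline(config):
    # Train QGPO offline on the downloaded dataset.
    qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz"))
    qgpo.train()

    # Deploy the trained policy and roll it out in the environment.
    agent = qgpo.deploy()
    env = gym.make(config.deploy.env.env_id)
    observation = env.reset()
    for _ in range(config.deploy.num_deploy_steps):
        env.render()
        observation, reward, done, _ = env.step(agent.act(observation))
        if done:
            # Start a new episode once the current one terminates.
            observation = env.reset()
    env.close()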

if __name__ == '__main__':
    log.info("config: \n{}".format(config))
    qgpo_pipeline(config)
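
The deployment loop in the script steps the environment but does not record rewards. A minimal sketch of collecting per-episode returns is shown below, assuming the older Gym interface pinned above (`reset()` returns an observation, `step()` returns a 4-tuple). `evaluate`, `CountdownEnv`, and `ZeroAgent` are hypothetical names for this example; the stub environment and agent only stand in for `gym.make(...)` and `qgpo.deploy()` so the snippet runs without extra dependencies.

```python
def evaluate(agent, env, num_steps):
    """Step `env` with `agent` for `num_steps` steps; return completed episode returns."""
    returns, episode_return = [], 0.0
    observation = env.reset()
    for _ in range(num_steps):
        observation, reward, done, _ = env.step(agent.act(observation))
        episode_return += reward
        if done:
            # Record the finished episode and start a new one.
            returns.append(episode_return)
            episode_return = 0.0
            observation = env.reset()
    return returns

class CountdownEnv:
    """Hypothetical stand-in: every episode lasts `horizon` steps, reward 1.0 per step."""
    def __init__(self, horizon=3):
        self.horizon = horizon
        self.steps = 0

    def reset(self):
        self.steps = 0
        return 0.0

    def step(self, action):
        self.steps += 1
        return float(self.steps), 1.0, self.steps >= self.horizon, {}

class ZeroAgent:
    """Hypothetical stand-in for a deployed agent: always returns the same action."""
    def act(self, observation):
        return 0.0
```

With a real setup, the agent returned by `qgpo.deploy()` and the environment created by `gym.make(config.deploy.env.env_id)` would be passed in place of the stubs. An episode still in progress when `num_steps` runs out is not counted.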

For more detailed examples and documentation, please refer to the GenerativeRL documentation.

Tutorials

We provide several case-based tutorials to help you better understand GenerativeRL. See more at tutorials.

Benchmark experiments

We offer some baseline experiments to evaluate the performance of generative reinforcement learning algorithms. See more at benchmark.

Contributing

We welcome contributions to GenerativeRL! If you are interested in contributing, please refer to the Contributing Guide.

License

GenerativeRL is licensed under the Apache License 2.0. See LICENSE for more details.
