diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 000000000..3c15b7be6 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,53 @@ +name: docs_pages_workflow + +on: [pull_request] + +permissions: + pull-requests: write + +jobs: + build_docs_job: + runs-on: ubuntu-latest + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: 3.8 + + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip + echo "dir={$(pip cache dir)}" >> $GITHUB_OUTPUT + + - name: pip cache + uses: actions/cache@v3 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/docs/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install docs requirements + run: | + python -m pip install -r docs/requirements.txt + + - name: make the sphinx docs + run: | + make -C docs clean + make -C docs html + + - uses: readthedocs/actions/preview@v1 + with: + project-slug: "trlx" + project-language: "en" + # see: https://github.com/readthedocs/actions/tree/main/preview + # message-template (optional): Text message to be injected by the action in the Pull Request description. It supports the following placeholders to be replaced: + # {docs-pr-index-url}: URL to the root of the documentation for the Pull Request preview. + # platform (optional): Read the Docs Community (community) or Read the Docs for Business (business). (default: community) + # single-version (optional): Set this to 'true' if your project is single version, so we can link to the correct URL. (default: 'false') diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d94d14f2a..cbf9fa775 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - id: check-case-conflict @@ -18,17 +18,24 @@ repos: args: [--fix=lf] - id: requirements-txt-fixer - id: trailing-whitespace -- repo: https://github.com/psf/black + - repo: https://github.com/psf/black rev: 23.1.0 hooks: - - id: black + - id: black files: ^(trlx|examples|tests|setup.py)/ -- repo: https://github.com/pycqa/isort + - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - - id: isort + - id: isort name: isort (python) -- repo: https://github.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 6.0.0 hooks: - - id: flake8 + - id: flake8 + - repo: https://github.com/codespell-project/codespell + rev: v2.2.2 + hooks: + - id: codespell + args: [--ignore-words, dictionary.txt] + additional_dependencies: + - tomli diff --git a/.readthedocs.yml b/.readthedocs.yml index c8f03ab0a..d5f60f2e8 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,9 +1,25 @@ +# .readthedocs.yml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required version: 2 +build: + os: "ubuntu-20.04" + tools: + python: "3.8" + +# Build documentation in the docs/ directory with Sphinx sphinx: - configuration: docs/source/conf.py + configuration: docs/conf.py + fail_on_warning: false + +# Optionally build your docs in additional formats such as PDF and ePub +formats: + - htmlzip +# 
Optionally set the version of Python and requirements required to build your docs python: - version: 3.9 install: - - requirements: docs/requirements.txt + - requirements: docs/requirements.txt diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..6f9ee79d6 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,11 @@ +# Change log + +Best viewed on [trlx.readthedocs.io](https://trlx.readthedocs.io/en/latest/changelog.html). + + + +## trlx 0.4.0 (2022-12-05) + +- python 3.8 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0c5169ed5..8f97f071d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,11 @@ Looking to improve `trlX`? Thanks for considering! -There are many ways to contribute, from writing tutorials in [Colab notebooks](https://colab.research.google.com) to improving the project's [documentation](https://trlx.readthedocs.io), submitting bug reports and feature requests, or even implementing new features themselves. See the outstanding [issues](https://github.com/CarperAI/trlx/issues) for ideas on where to begin. +There are many ways to contribute, from writing tutorials in [Colab notebooks](https://colab.research.google.com) to improving the project's [documentation](https://trlx.readthedocs.io), to submitting bug reports and feature requests, or even implementing new features themselves. See the outstanding [issues](https://github.com/CarperAI/trlx/issues) for ideas on where to begin. + +- [Documentation Issues](https://github.com/CarperAI/trlx/issues?q=is%3Aissue+is%3Aopen+label%3Adocumentation) +- [Bug Fixes](https://github.com/CarperAI/trlx/issues?q=is%3Aissue+is%3Aopen+label%3Abug) +- [Feature Requests](https://github.com/CarperAI/trlx/issues?q=is%3Aissue+is%3Aopen+label%3A%22feature+request%22) Here are some guidelines to help you get started 🚀. @@ -16,40 +20,83 @@ To submit a bug report or a feature request, please open an [issue](https://gith Follow these steps to start contributing code: +1. Setup your environment: + +```bash +conda create -n trlx python=3.8 torch torch-cuda=11.7 -c pytorch -c nvidia +git clone https://github.com/CarperAI/trlx +cd trlx +pip install -e ".[dev]" +pre-commit install +``` + 1. Create your own [fork](https://docs.github.com/en/get-started/quickstart/fork-a-repo#forking-a-repository) of the repository and clone it to your local machine. + ```bash git clone https://github.com//trlx.git cd trlx git remote add upstream https://github.com/CarperAI/trlx.git ``` -2. Create a new branch for your changes and give it a concise name that reflects your contribution. + +1. Create a new branch for your changes and give it a concise name that reflects your contribution. + ```bash git checkout -b ``` -2. Install the development dependencies in a Python environment. + +1. Install the development dependencies in a Python environment. + ```bash pip install -e ".[dev]" pre-commit install ``` -4. Implement your changes. Make small, independent, and well documented commits along the way (check out [these](https://cbea.ms/git-commit/) tips). -5. Add unit tests whenever appropriate and ensure that the tests pass. To run the entire test suite, use the following command from within the project root directory. + +install pre-commit + +```bash +pip install pre-commit +pre-commit install +``` + +bonus: force run pre-commit on all the files + +```bash +pre-commit run --all-files +``` + +1. Implement your changes. Make small, independent, and well documented commits along the way (check out [these](https://cbea.ms/git-commit/) tips). + +1. 
Add unit tests whenever appropriate and ensure that the tests pass. To run the entire test suite, use the following command from within the project root directory. + ```bash pytest ``` + For changes with minimal project scope (e.g. a simple bug fix), you might want to run the unit tests for just a specific test file instead: + ```bash pytest -vv -k "" ``` -5. Commit your final changes. Our `pre-commit` hooks will automatically run before each commit and will prevent you from committing code that does not pass our style and linter checks. They'll also automatically format your code! To run these manually, use the following command: + +1. Commit your final changes. Our `pre-commit` hooks will automatically run before each commit and will prevent you from committing code that does not pass our style and linter checks. They'll also automatically format your code! To run these manually, use the following command: + ```bash pre-commit run --all-files ``` -6. Push the changes to your fork. +1. Push the changes to your fork. Finally ... 🥁 ... Create a [pull request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request) to the `trlX` repository! Make sure to include a description of your changes and link to any relevant issues. -> __Tip__: If you're looking to introduce an experimental feature, we suggest testing the behavior of your proposed feature on some of the existing [examples](https://github.com/CarperAI/trlx/tree/master/examples), such as [random walks](https://github.com/CarperAI/trlx/blob/master/examples/randomwalks). This will help you get a better sense of how the feature would work in practice and will also help you identify any potential flaws in the implementation. +> **Tip**: If you're looking to introduce an experimental feature, we suggest testing the behavior of your proposed feature on some of the existing [examples](https://github.com/CarperAI/trlx/tree/master/examples), such as [random walks](https://github.com/CarperAI/trlx/blob/master/examples/randomwalks). This will help you get a better sense of how the feature would work in practice and will also help you identify any potential flaws in the implementation. + +## Tips & Tricks + +Set transformers verbosity level + +```bash +TRANSFORMERS_VERBOSITY=error +``` ## Asking questions @@ -63,4 +110,4 @@ This project adheres to the [Contributor Covenant Code of Conduct](https://githu By contributing, you agree that your contributions will be licensed under its MIT License. -# Thank you for your contribution 🐠! +## Thank you for your contribution! 🐠 diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..b13cd2dd2 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +gendoc: + docker build -t trlxgendocs -f docker/docs/Dockerfile . +run: + docker run --rm -it -v ${PWD}:/build \ + --entrypoint /bin/bash \ + trlxgendocs diff --git a/README.md b/README.md index 27b8ffa7d..5ca573cb7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +![TRLX](./docs/_static/apple-touch-icon-114x114.png) + [docs-image]: https://readthedocs.org/projects/trlX/badge/?version=latest [docs-url]: https://trlX.readthedocs.io/en/latest/?badge=latest @@ -12,6 +14,7 @@ You can read more about trlX in our [documentation](https://trlX.readthedocs.io) Want to collect human annotations for your RL application? Check out [CHEESE!](https://github.com/carperai/cheese), our library for HiTL data collection. 
## Installation + ```bash git clone https://github.com/CarperAI/trlx.git cd trlx @@ -28,23 +31,29 @@ For more usage see [examples](./examples). You can also try the colab notebooks ## How to Train + You can train a model using a reward function or a reward-labeled dataset. -#### Using a reward function +### Using a reward function + ```python trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('cats') for sample in samples]) ``` -#### Using a reward-labeled dataset + +### Using a reward-labeled dataset + ```python trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)]) ``` -#### Trainers provide a wrapper over their underlying model +### Trainers provide a wrapper over their underlying model + ```python trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True) ``` -#### Save the resulting model to a Hugging Face pretrained language model. (Ready to upload to the Hub!) +### Save the resulting model to a Hugging Face pretrained language model. (Ready to upload to the Hub!) + ```python trainer.save_pretrained('/path/to/output/folder/') ``` @@ -69,46 +78,11 @@ python examples/nemo_ilql_sentiments.py For more usage see the [NeMo README](./trlx/trainer/nemo) #### Use Ray Tune to launch hyperparameter sweep + ```bash python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py ``` -## Logging - -trlX uses the standard Python `logging` library to log training information to the console. The default logger is set to the `INFO` level, which means that `INFO`, `WARNING`, `ERROR`, and `CRITICAL` level messages will be printed to standard output. - -To change the log level directly, you can use the verbosity setter. For example, to set the log level to `WARNING` use: - -```python -import trlx - -trlx.logging.set_verbosity(trlx.logging.WARNING) -``` - -This will suppress `INFO` level messages, but still print `WARNING`, `ERROR`, and `CRITICAL` level messages. - -You can also control logging verbosity by setting the `TRLX_VERBOSITY` environment variable to one of the standard logging [level names](https://docs.python.org/3/library/logging.html#logging-levels): - -* `CRITICAL` (`trlx.logging.CRITICAL`) -* `ERROR` (`trlx.logging.ERROR`) -* `WARNING` (`trlx.logging.WARNING`) -* `INFO` (`trlx.logging.INFO`) -* `DEBUG` (`trlx.logging.DEBUG`) - -```sh -export TRLX_VERBOSITY=WARNING -``` - -By default, [`tqdm`](https://tqdm.github.io/docs/tqdm/) progress bars are used to display training progress. You can disable them by calling `trlx.logging.disable_progress_bar()`, otherwise `trlx.logging.enable_progress_bar()` to enable. - -Messages can be formatted with greater detail by setting `trlx.logging.enable_explicit_format()`. This will inject call-site information into each log which may be helpful for debugging. - -```sh -[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message... -``` - -> 💡 Tip: To reduce the amount of logging output, you might find it helpful to change log levels of third-party libraries used by trlX. For example, try adding `transformers.logging.set_verbosity_error()` to the top of your trlX scripts to silence verbose messages from the `transformers` library (see their [logging docs](https://huggingface.co/docs/transformers/main_classes/logging#logging) for more details). 
- ## Contributing For development check out these [guidelines](./CONTRIBUTING.md) diff --git a/dictionary.txt b/dictionary.txt new file mode 100644 index 000000000..1cd3c8cd5 --- /dev/null +++ b/dictionary.txt @@ -0,0 +1 @@ +rouge diff --git a/docker/docs/Dockerfile b/docker/docs/Dockerfile new file mode 100644 index 000000000..7c8287b94 --- /dev/null +++ b/docker/docs/Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.8-slim + +# pip install -r docs/requirements.txt +# sphinx-build -b html docs docs/build/html -j auto +# sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto + +RUN python -m pip install --upgrade --no-cache-dir pip +ADD docs/requirements.txt /tmp/requirements.txt +RUN python -m pip install --exists-action=w --no-cache-dir -r /tmp/requirements.txt +RUN mkdir /build +WORKDIR /build/docs/ +ENTRYPOINT /build/docs/build.sh +# RUN `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf10..ed8809902 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build -SOURCEDIR = source +SOURCEDIR = . BUILDDIR = build # Put it first so that "make" without argument is like "make help". diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..62a4ae956 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,5 @@ +# How To build the documentation + +```bash +make -c docs html +``` diff --git a/docs/_static/apple-touch-icon-114x114.png b/docs/_static/apple-touch-icon-114x114.png new file mode 100644 index 000000000..f41c09a09 Binary files /dev/null and b/docs/_static/apple-touch-icon-114x114.png differ diff --git a/docs/_static/apple-touch-icon-120x120.png b/docs/_static/apple-touch-icon-120x120.png new file mode 100644 index 000000000..6ff0b418e Binary files /dev/null and b/docs/_static/apple-touch-icon-120x120.png differ diff --git a/docs/_static/apple-touch-icon-144x144.png b/docs/_static/apple-touch-icon-144x144.png new file mode 100644 index 000000000..4657719af Binary files /dev/null and b/docs/_static/apple-touch-icon-144x144.png differ diff --git a/docs/_static/apple-touch-icon-152x152.png b/docs/_static/apple-touch-icon-152x152.png new file mode 100644 index 000000000..7bf86af1c Binary files /dev/null and b/docs/_static/apple-touch-icon-152x152.png differ diff --git a/docs/_static/apple-touch-icon-167x167.png b/docs/_static/apple-touch-icon-167x167.png new file mode 100644 index 000000000..7e2945647 Binary files /dev/null and b/docs/_static/apple-touch-icon-167x167.png differ diff --git a/docs/_static/apple-touch-icon-180x180.png b/docs/_static/apple-touch-icon-180x180.png new file mode 100644 index 000000000..31c726d0f Binary files /dev/null and b/docs/_static/apple-touch-icon-180x180.png differ diff --git a/docs/_static/apple-touch-icon-57x57.png b/docs/_static/apple-touch-icon-57x57.png new file mode 100644 index 000000000..25b5a2d8a Binary files /dev/null and b/docs/_static/apple-touch-icon-57x57.png differ diff --git a/docs/_static/apple-touch-icon-60x60.png b/docs/_static/apple-touch-icon-60x60.png new file mode 100644 index 000000000..eb7bfdc4f Binary files /dev/null and b/docs/_static/apple-touch-icon-60x60.png differ diff --git a/docs/_static/apple-touch-icon-72x72.png b/docs/_static/apple-touch-icon-72x72.png new file mode 100644 index 000000000..0562e6de7 Binary files /dev/null and b/docs/_static/apple-touch-icon-72x72.png differ diff --git 
a/docs/_static/apple-touch-icon-76x76.png b/docs/_static/apple-touch-icon-76x76.png new file mode 100644 index 000000000..084ad067c Binary files /dev/null and b/docs/_static/apple-touch-icon-76x76.png differ diff --git a/docs/_static/favicon-128x128.png b/docs/_static/favicon-128x128.png new file mode 100644 index 000000000..4e43cc31f Binary files /dev/null and b/docs/_static/favicon-128x128.png differ diff --git a/docs/_static/favicon-16x16.png b/docs/_static/favicon-16x16.png new file mode 100644 index 000000000..e06e67ffc Binary files /dev/null and b/docs/_static/favicon-16x16.png differ diff --git a/docs/_static/favicon-196x196.png b/docs/_static/favicon-196x196.png new file mode 100644 index 000000000..fcea049fc Binary files /dev/null and b/docs/_static/favicon-196x196.png differ diff --git a/docs/_static/favicon-32x32.png b/docs/_static/favicon-32x32.png new file mode 100644 index 000000000..5008598c0 Binary files /dev/null and b/docs/_static/favicon-32x32.png differ diff --git a/docs/_static/favicon-96x96.png b/docs/_static/favicon-96x96.png new file mode 100644 index 000000000..9d11839a5 Binary files /dev/null and b/docs/_static/favicon-96x96.png differ diff --git a/docs/_static/style.css b/docs/_static/style.css new file mode 100644 index 000000000..2fac0848d --- /dev/null +++ b/docs/_static/style.css @@ -0,0 +1,26 @@ +@import url("theme.css"); + +:root { + --block-bg-opacity: .5; +} + +.wy-side-nav-search { + background-color: #fff; +} + +.getting-started { + background-color: rgba(78, 150, 253, var(--block-bg-opacity)); +} + +.user-guides { + background-color: rgba(0, 169, 154, var(--block-bg-opacity)); +} + +.developer-docs { + background-color: rgba(171, 0, 182, var(--block-bg-opacity)); +} + +.key-ideas +{ + border: 0px +} diff --git a/docs/_static/trlx-logo-512x512.png b/docs/_static/trlx-logo-512x512.png new file mode 100644 index 000000000..d8bdb400c Binary files /dev/null and b/docs/_static/trlx-logo-512x512.png differ diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html new file mode 100644 index 000000000..4c57ba830 --- /dev/null +++ b/docs/_templates/layout.html @@ -0,0 +1,2 @@ +{% extends "!layout.html" %} +{% set css_files = css_files + ["_static/style.css"] %} diff --git a/docs/build.sh b/docs/build.sh new file mode 100755 index 000000000..147ebab99 --- /dev/null +++ b/docs/build.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +`which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 000000000..66efc0fec --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1,2 @@ +```{include} ../CHANGELOG.md +``` diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 000000000..32a8c2df3 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,187 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. 
+# +import os +import sys + +sys.path.insert(0, os.path.abspath("..")) + + +# -- Project information ----------------------------------------------------- + +project = "trlX" +copyright = "2023, CarperAI" +author = "CarperAI" + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "matplotlib.sphinxext.plot_directive", + "sphinx_autodoc_typehints", + "myst_nb", + # "myst_parser", + "sphinx_remove_toctrees", + "sphinx_copybutton", + "sphinx_design", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "numpy": ("https://docs.scipy.org/doc/numpy/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), + "pytorch": ("https://pytorch.readthedocs.io/", None), +} + +autodoc_preserve_defaults = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = [".rst", ".md"] + +# The master toctree document. +main_doc = "index" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [ + # Sometimes sphinx reads its own outputs as inputs! + "build/html", +] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = None + +autosummary_generate = True +napolean_use_rtype = False + +# -- Options for nbsphinx ----------------------------------------------------- + +# Execute notebooks before conversion: 'always', 'never', 'auto' (default) +# We execute all notebooks, exclude the slow ones using 'exclude_patterns' +nbsphinx_execute = "always" + +# Use this kernel instead of the one stored in the notebook metadata: +# nbsphinx_kernel_name = 'python3' + +# List of arguments to be passed to the kernel that executes the notebooks: +# nbsphinx_execute_arguments = [] + +# If True, the build process is continued even if an exception occurs: +# nbsphinx_allow_errors = True + + +# Controls when a cell will time out (defaults to 30; use -1 for no timeout): +nbsphinx_timeout = 180 + +# Default Pygments lexer for syntax highlighting in code cells: +# nbsphinx_codecell_lexer = 'ipython3' + +# Width of input/output prompts used in CSS: +# nbsphinx_prompt_width = '8ex' + +# If window is narrower than this, input/output prompts are on separate lines: +# nbsphinx_responsive_width = '700px' + +# This is processed by Jinja2 and inserted before each notebook +nbsphinx_prolog = r""" # noqa: E501 +{% set docname = 'docs/' + env.doc2path(env.docname, base=None) %} +.. only:: html + .. role:: raw-html(raw) + :format: html + .. nbinfo:: + Interactive online version: + :raw-html:`Open In Colab` + __ https://github.com/CarperAI/trlx/blob/ + {{ env.config.release }}/{{ docname }} +""" + +# This is processed by Jinja2 and inserted after each notebook +# nbsphinx_epilog = r""" +# """ + +# Input prompt for code cells. "%s" is replaced by the execution count. +# nbsphinx_input_prompt = 'In [%s]:' + +# Output prompt for code cells. "%s" is replaced by the execution count. 
+# nbsphinx_output_prompt = 'Out[%s]:' + +# Specify conversion functions for custom notebook formats: +# import jupytext +# nbsphinx_custom_formats = { +# '.Rmd': lambda s: jupytext.reads(s, '.Rmd'), +# } + +# Link or path to require.js, set to empty string to disable +# nbsphinx_requirejs_path = '' + +# Options for loading require.js +# nbsphinx_requirejs_options = {'async': 'async'} + +# mathjax_config = { +# 'TeX': {'equationNumbers': {'autoNumber': 'AMS', 'useLabelIds': True}}, +# } + +# Additional files needed for generating LaTeX/PDF output: +# latex_additional_files = ['references.bib'] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_book_theme" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# Output file base name for HTML help builder. +htmlhelp_basename = "TRLXdoc" + +# -- Extension configuration ------------------------------------------------- + +# Tell sphinx-autodoc-typehints to generate stub parameter annotations including +# types, even if the parameters aren't explicitly documented. +always_document_param_types = True + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +html_theme_options = { + # "logo_only": True, + "show_toc_level": 2, + "repository_url": "https://github.com/CarperAI/trlx", + "use_repository_button": True, # add a "link to repository" button +} + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +html_logo = "_static/apple-touch-icon-144x144.png" + +html_favicon = "_static/favicon-16x16.png" diff --git a/docs/source/configs.rst b/docs/configs.rst similarity index 86% rename from docs/source/configs.rst rename to docs/configs.rst index da5e1f2e6..0e2abd369 100644 --- a/docs/source/configs.rst +++ b/docs/configs.rst @@ -25,10 +25,10 @@ the specific method being used (i.e. ILQL or PPO) **PPO** -.. autoclass:: trlx.data.method_configs.PPOConfig +.. autoclass:: trlx.trainer.nn.ppo_models.MethodConfig :members: **ILQL** -.. autoclass:: trlx.data.method_configs.ILQLConfig +.. autoclass:: trlx.trainer.nn.ilql_models.ILQLConfig :members: diff --git a/docs/source/data.rst b/docs/data.rst similarity index 100% rename from docs/source/data.rst rename to docs/data.rst diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 000000000..3f518b7d9 --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,26 @@ +# Examples + +In the ``examples`` folder you can find several example training tasks. + +Check the configs folder for the associated configs files. + +## [randomwalks](/examples.randomwalks) + +does offline reinforcement on a set of graph random walks to stitch shortest paths +to some destination. + +## simulacra + +optimizes prompts by using [prompts-ratings dataset](https://github.com/JD-P/simulacra-aesthetic-captions). + +## architext + +tries to optimize designs represented textually by minimizing number of rooms (pre-trained model is under a license on hf). 
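+
+Each of these examples is a standalone script. A minimal sketch of launching one on a single GPU
+(assuming the default config files shipped in `configs/` are picked up) could look like:
+
+```bash
+# run the PPO sentiment example with its default config
+python examples/ppo_sentiments.py
+
+# multi-GPU runs can instead go through Accelerate, e.g.
+# accelerate config && accelerate launch examples/ppo_sentiments.py
+```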
+
+## ilql_sentiments and ppo_sentiments
+
+train models to generate movie reviews with a positive sentiment: in the offline setting by fitting to IMDB
+dataset sentiment scores, and in the online setting by sampling from a model fine-tuned on IMDB
+and scoring the samples with a learned sentiment reward model. You can tweak
+these scripts to your liking and tune hyperparameters to your problem if you
+wish to use trlx for some custom task.
diff --git a/docs/faq.md b/docs/faq.md
new file mode 100644
index 000000000..e663db2d2
--- /dev/null
+++ b/docs/faq.md
@@ -0,0 +1,8 @@
+# Frequently Asked Questions
+
+```{admonition} How to add a new page to the documentation?
+RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html
+```
+
+We are collecting answers to frequently asked questions here.
+Contributions welcome!
diff --git a/docs/glossary.md b/docs/glossary.md
new file mode 100644
index 000000000..aef6f47bb
--- /dev/null
+++ b/docs/glossary.md
@@ -0,0 +1,81 @@
+# Glossary of Terms
+
+```{glossary}
+[Agent]()
+  An agent in reinforcement learning is the entity that interacts with the {term}`Environment` to learn how to maximize its {term}`Reward`.
+
+[Action]()
+  An action in reinforcement learning is the signal that the {term}`Agent` provides to the {term}`Environment` to indicate what it wants to do.
+
+  In other words, an action is the choice the agent makes at each step (in trlX, typically the next token to generate). The agent's goal is to maximize the total reward it receives over a sequence of {term}`Steps`.
+
+[CPU](https://en.wikipedia.org/wiki/Central_processing_unit)
+  Short for *Central Processing Unit*, CPUs are the standard computational architecture
+  available in most computers. trlX can run computations on CPUs, but often can achieve
+  much better performance on a {term}`GPU`.
+
+[Device](https://en.wikipedia.org/wiki/Device_computing)
+  The generic name used to refer to the {term}`CPU` or {term}`GPU` that trlX uses
+  to perform computations.
+
+[Environment]()
+  An environment in reinforcement learning is the system that the agent interacts with. It is the source of {term}`State`, {term}`Action`, and {term}`Reward`.
+
+  In other words, an environment is a system that defines the agent's observation space, action space, and reward function. It is the source of the agent's experience, and the goal of the agent is to maximize the total reward it receives over a sequence of {term}`Steps`.
+
+[GPU](https://en.wikipedia.org/wiki/Graphics_processing_unit)
+  Short for *Graphical Processing Unit*, GPUs were originally specialized for operations
+  related to rendering of images on screen, but now are much more general-purpose. trlX is
+  able to target GPUs for fast operations on arrays (see also {term}`CPU`).
+
+[Policy]()
+  A policy in reinforcement learning is a function that maps {term}`State` to {term}`Action`.
+
+  In other words, a policy is a function that maps the agent's current state to the action it should take. The agent's goal is to maximize the total reward it receives over a sequence of {term}`Steps`.
+
+[PPO](https://arxiv.org/abs/1707.06347)
+  Short for *Proximal Policy Optimization*, PPO is a {term}`Policy Gradient` algorithm
+  that is able to learn policies in high-dimensional, continuous action spaces.
+
+[Policy Gradient](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#policy-gradients)
+  Policy gradient methods are a class of reinforcement learning algorithms that are able to
+  learn policies in high-dimensional, continuous action spaces.
+ +[Reinforcement Learning](https://en.wikipedia.org/wiki/Reinforcement_learning) + Reinforcement learning (RL) is a machine learning paradigm that trains an agent to maximize its + {term}`Reward` by interacting with an {term}`Environment`. + +[Reward]() + A reward in reinforcement learning is the signal that the {term}`Environment` provides to the {term}`Agent` to indicate how well it is performing. + + In other words, a reward is a scalar value that the environment provides to the agent to indicate how well it is performing. The agent's goal is to maximize the total reward it receives over a sequence of {term}`Steps`. + +[Rollout]() + A rollout in reinforcement learning is the process of executing a {term}`Policy`, starting from a specific state in the {term}`Environment`, and following it to the end to obtain a complete trajectory of {term}`State`, {term}`Action`, and {term}`Reward`. + + In other words, a Rollout is a simulation of a policy's behavior in the environment over a fixed number of {term}`Steps` or until a terminal state is reached. It provides a means of evaluating the {term}`Policy`'s performance, as the total reward collected over the trajectory can be used as a measure of its effectiveness. + +[State]() + A state in reinforcement learning is the observation that the {term}`Environment` provides to the {term}`Agent`. + +[Steps]() + A step in reinforcement learning is the process of taking a single {term}`Action` in the {term}`Environment`, and observing the resulting {term}`State` and {term}`Reward`. + + In other words, a step is a single iteration of the environment's dynamics, where the agent takes an action and receives a reward and a new state. The agent's goal is to maximize the total reward it receives over a sequence of steps. + +[Trajectory] + + In a {term}`PPO` (Proximal Policy Optimization) setup, a fixed-length trajectory + segment refers to a fixed number of time steps in an episode of an + environment.At each time step, the agent takes an action based on the current + state and receives a reward from the environment. By using fixed-length + trajectory segments, the agent's behavior is divided into chunks of a fixed + length, and each chunk is used for a single PPO update. This allows for more + efficient use of the {term}`Agent`'s experience by breaking it into smaller pieces, and + it also helps to stabilize the learning process by making the training updates + less sensitive to the length of the episode. Fixed-length trajectory segments + are often used in Reinforcement Learning (RL) algorithms, including {term}`PPO`, to + update the policy network. + +``` diff --git a/docs/source/index.rst b/docs/index.rst similarity index 76% rename from docs/source/index.rst rename to docs/index.rst index 782e29ecc..fe381aaf0 100644 --- a/docs/source/index.rst +++ b/docs/index.rst @@ -8,6 +8,15 @@ Welcome to trlX's documentation! trlX is a library made for training large language models using reinforcement learning. It currently supports training using PPO or ILQL for models up to 20B using Accelerate. +Installation +------------ +.. tab-set:: + + .. code-block:: bash + + pip install "trlx" + + .. toctree:: :maxdepth: 2 :caption: Contents: @@ -17,8 +26,19 @@ currently supports training using PPO or ILQL for models up to 20B using Acceler orchestrator configs pipeline + trainer examples +.. 
toctree:: + :hidden: + :maxdepth: 1 + :caption: Resources + + changelog + faq + glossary + + Indices and tables ================== diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 000000000..91361720e --- /dev/null +++ b/docs/models.md @@ -0,0 +1 @@ +# Models diff --git a/docs/source/orchestrator.rst b/docs/orchestrator.rst similarity index 100% rename from docs/source/orchestrator.rst rename to docs/orchestrator.rst diff --git a/docs/pipeline.md b/docs/pipeline.md new file mode 100644 index 000000000..066a28de5 --- /dev/null +++ b/docs/pipeline.md @@ -0,0 +1,30 @@ +# Pipelines and Rollout Store + +## Pipelines + +Pipelines in trlX provide a way to read from a dataset. They are used to fetch data from the dataset and feed it to the models for training or inference. The pipelines allow for efficient processing of the data and ensure that the models have access to the data they need for their tasks. + +## Rollout Stores + +Rollout stores in trlX are used to store experiences created for the models by the orchestrator. The experiences in the rollout stores serve as the training data for the models. The models use the experiences stored in their rollout stores to learn and improve their behavior. The rollout stores provide a convenient and efficient way for the models to access the experiences they need for training. + +## General + +.. autoclass:: trlx.pipeline.BasePipeline + :members: + +.. autoclass:: trlx.pipeline.BaseRolloutStore + :members: + +## PPO + +.. autoclass:: trlx.pipeline.ppo_pipeline.PPORolloutStorage + :members: + +## ILQL + +.. autoclass:: trlx.pipeline.offline_pipeline.PromptPipeline + :members: + +.. autoclass:: trlx.pipeline.offline_pipeline.ILQLRolloutStorage + :members: diff --git a/docs/requirements.txt b/docs/requirements.txt index 7a33f300e..3052a2f0c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,11 +1,20 @@ -accelerate==0.12.0 -datasets==2.4.0 -deepspeed==0.7.3 -einops==0.4.1 -numpy==1.23.2 -sphinx==4.0.0 -sphinx_rtd_theme +accelerate +commonmark +datasets +deepspeed +docutils +jupyter-sphinx +matplotlib +myst-nb +nbsphinx +Pygments +ray +readthedocs-sphinx-ext +rich +sphinx-autodoc-typehints +sphinx-book-theme +sphinx-copybutton +sphinx-design +sphinx-remove-toctrees torchtyping -tqdm==4.64.0 -transformers==4.21.2 -wandb==0.13.2 +transformers diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index 0a9a11c86..000000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,54 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys - -import sphinx_rtd_theme - -sys.path.insert(0, os.path.abspath('../..')) - - -# -- Project information ----------------------------------------------------- - -project = 'trlX' -copyright = '2022, CarperAI' -author = 'CarperAI' - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. 
They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. - -extensions = ['sphinx_rtd_theme', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel'] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] diff --git a/docs/source/examples.rst b/docs/source/examples.rst deleted file mode 100644 index 6f5db49d1..000000000 --- a/docs/source/examples.rst +++ /dev/null @@ -1,18 +0,0 @@ -.. _examples: - -Examples -************************ - -In the ``examples`` folder you can find several example training tasks. Check -the configs folder for the associated configs files. ``examples.randomwalks`` -does offline reinforcement on a set of graph random walks to stitch shortest -paths to some destination. ``examples.simulacra`` optimizes prompts by using -prompts-ratings dataset (https://github.com/JD-P/simulacra-aesthetic-captions). -``examples.architext`` tries to optimize designs represented textually by -minimazing number of rooms (pretrained model is under a license on hf). -``examples.ilql_sentiments`` and ``examples.ppo_sentiments`` train to generate -movie reviews with a positive sentiment, in offline setting – by fitting to IMDB -dataset sentiment scores, and in online setting – by sampling finetuned on IMDB -model and rating samples with learned sentiment reward model, You can tweak -these scripts to your liking and tune hyperparameters to your problem if you -wish to use trlx for some custom task. diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst deleted file mode 100644 index 04d1a8c04..000000000 --- a/docs/source/pipeline.rst +++ /dev/null @@ -1,28 +0,0 @@ -.. _pipeline: - -Pipelines -************************ - -Pipelines are how you read from a dataset with trlX. Rollout stores are how models store experiences created -for them by the orchestrator. It is these experiences in their rollout store that they are trained on. - -**General** - -.. autoclass:: trlx.pipeline.BasePipeline - :members: - -.. autoclass:: trlx.pipeline.BaseRolloutStore - :members: - -**PPO** - -.. autoclass:: trlx.pipeline.ppo_pipeline.PPORolloutStorage - :members: - -**ILQL** - -.. autoclass:: trlx.pipeline.offline_pipeline.PromptPipeline - :members: - -.. 
autoclass:: trlx.pipeline.offline_pipeline.ILQLRolloutStorage - :members: diff --git a/docs/source/trainer.rst b/docs/trainer.rst similarity index 100% rename from docs/source/trainer.rst rename to docs/trainer.rst diff --git a/examples/experiments/grounded_program_synthesis/lang.py b/examples/experiments/grounded_program_synthesis/lang.py index d2436c3f6..9c3f076c0 100644 --- a/examples/experiments/grounded_program_synthesis/lang.py +++ b/examples/experiments/grounded_program_synthesis/lang.py @@ -109,7 +109,7 @@ def __call__(self, statement_string: str): # This is used to store the input, output and the function template. # Input : List given as an input to the function. # function_template : The atomic function in a given DSL Grammar -# Output : Transformed outut by applying function on the input. +# Output : Transformed output by applying function on the input. generation_template = {"function_template": "NONE", "output": "NONE", "input": []} diff --git a/examples/experiments/grounded_program_synthesis/train_trlx.py b/examples/experiments/grounded_program_synthesis/train_trlx.py index 8071fc210..6cfe793a0 100644 --- a/examples/experiments/grounded_program_synthesis/train_trlx.py +++ b/examples/experiments/grounded_program_synthesis/train_trlx.py @@ -17,7 +17,7 @@ def __init__(self): self.train_data = json.load(f) with open("dataset/test.json", "r") as f: self.test_data = json.load(f) - logger.info("Sucessfully loaded the dataset") + logger.info("Successfully loaded the dataset") def load_datapoints(self, split="train"): if split == "train": @@ -74,7 +74,7 @@ def main(hparams={}): if __name__ == "__main__": - # TEST REWARD FUNTION + # TEST REWARD FUNCTION assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)"])) == [1] assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -a]),1)"])) == [-1] assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"])) == [-0.5] diff --git a/pyproject.toml b/pyproject.toml index f6ec87f55..40e261241 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,3 +8,8 @@ profile = "black" [tool.black] line-length = 120 + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra -q --doctest-modules -vv --cov=trlx/ " +testpaths = ["tests"] diff --git a/setup.cfg b/setup.cfg index 4a54f7747..2bf3cca73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,38 +1,65 @@ [metadata] -name = trlx -author = Alex Havrilla -version = 0.3.0 -url = https://github.com/CarperAI/trlx -description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) +author = The CarperAI team +description = Transformer Reinforcement Learning X: A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) +keywords = + Deep Learning + Reinforcement Learning +license = MIT long_description = file: README.md long_description_content_type = text/markdown -license = MIT +name = trlx +url = https://github.com/CarperAI/trlx +version = v0.4.0-dev [options] +python_requires = + >=3.8.0 packages = find: install_requires = accelerate>=0.12.0 + accelerate~=0.15.0 datasets deepspeed>=0.7.3 einops>=0.4.1 + networkx + numpy>=1.17 numpy>=1.23.2 + packaging>=20.0 + psutil + pyyaml + ray + ray>=2.0.1 + tabulate>=0.9.0 + torch~=1.13.0 torchtyping - transformers>=4.21.2 + torchtyping~=0.1.4 tqdm + transformers>=4.21.2 + typing-extensions~=3.10.0 rich - wandb>=0.13.5 - ray>=2.0.1 - tabulate>=0.9.0 - networkx + +classifiers = + 
Development Status :: 3 - Alpha + Intended Audience :: Developers + Intended Audience :: Education + Intended Audience :: Science/Research + License :: OSI Approved :: MIT>=0.13.5 + Operating System :: OS Independent + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.8 + Topic :: Scientific/Engineering :: Artificial Intelligence [options.extras_require] -bnb = bitsandbytes +bnb = + bitsandbytes +wandb = + wandb dev = black isort flake8 - pre-commit - pytest + pre-commit>= 2.21.0 + pytest>=6.0 pytest-cov [options.packages.find] diff --git a/setup.py b/setup.py index 606849326..e3d859351 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,35 @@ -from setuptools import setup +from setuptools import find_packages, setup -setup() +extras = {} +extras["quality"] = [ + "black ~= 22.10.0", + "isort >= 5.5.4", + "flake8 >= 3.8.3", +] +extras["docs"] = ["sphinx==4.0.0", "sphinx_rtd_theme"] +extras["test_prod"] = ["pytest", "pytest-xdist", "pytest-subtests", "parameterized"] +extras["test_dev"] = [ + "datasets", + "evaluate", + "transformers", + "scipy", + "scikit-learn", + "deepspeed<0.7.7", + "tqdm", +] +extras["testing"] = extras["test_prod"] + extras["test_dev"] +extras["rich"] = ["rich"] + +extras["test_trackers"] = ["wandb"] +extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"] + + +setup( + packages=find_packages("."), + entry_points={ + "console_scripts": [ + "trlx=trlx.commands.trlx_cli:main", + ] + }, + +) diff --git a/trlx/__init__.py b/trlx/__init__.py index 7b26a92a9..e84114dc7 100644 --- a/trlx/__init__.py +++ b/trlx/__init__.py @@ -1,2 +1 @@ from .trlx import train -from .utils import logging diff --git a/trlx/orchestrator/offline_orchestrator.py b/trlx/orchestrator/offline_orchestrator.py index d426e7aad..90207b19a 100644 --- a/trlx/orchestrator/offline_orchestrator.py +++ b/trlx/orchestrator/offline_orchestrator.py @@ -1,16 +1,11 @@ -import os from typing import List, Union import numpy as np import torch -from rich.console import Console -from rich.table import Table -import trlx.utils.logging as logging from trlx.orchestrator import Orchestrator, register_orchestrator from trlx.pipeline.offline_pipeline import ILQLRolloutStorage - -logger = logging.get_logger(__name__) +from trlx.utils import print_rank_0 def tokenize_dialogue( # noqa: C901 @@ -65,8 +60,6 @@ def make_experience(self, samples, rewards, max_length=2048): """ Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer """ - logger.info("Collecting rollouts") - if self.trainer.tokenizer: samples = [tokenize_dialogue(s, self.trainer.tokenizer, max_length) for s in samples] @@ -91,29 +84,26 @@ def make_experience(self, samples, rewards, max_length=2048): all_actions_ixs.append(torch.hstack(actions_ixs)) all_states_ixs.append(states_ixs) - if self.trainer.tokenizer and os.environ.get("RANK", "0") == "0": - logger.info("Logging sample example") + if self.trainer.tokenizer: prompt = self.trainer.tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]]) response = self.trainer.tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :]) - columns = ["Prompt", "Response", "Reward"] - table = Table(*columns, title="Sample Example", show_lines=True) - table.add_row(prompt, response, str(rewards[0])) - Console().print(table) + print_rank_0("[Sample example]") + print_rank_0("Prompt: ", prompt) + print_rank_0("Response: ", response) + print_rank_0("Reward: ", rewards[0]) sample_lengths = np.array(list(map(len, all_input_ids))) 
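        # length bookkeeping for the stats printed below: `all_actions_ixs` holds the indices of
        # generated (action) tokens, so the output length is their count and the prompt length is the remainder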
output_lengths = np.array(list(map(len, all_actions_ixs))) prompt_lengths = sample_lengths - output_lengths returns = torch.tensor(rewards, dtype=float) - if os.environ.get("RANK", "0") == "0": - logger.info("Logging experience string statistics") - columns = ["Prompt Length", "Output Length", "Sample Length"] - table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True) - row = [] - for lengths in [prompt_lengths, output_lengths, sample_lengths]: - row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]") - table.add_row(*row) - Console().print(table) + def string_stats(name: str, xs: np.array): + return f"[Mean {name}] {xs.mean():.2f} ∈ [{min(xs)}, {max(xs)}]" + + print_rank_0(string_stats("prompt length", prompt_lengths)) + print_rank_0(string_stats("output length", output_lengths)) + print_rank_0(string_stats("sample length", sample_lengths)) + print_rank_0(string_stats("return", returns)) returns = (returns - returns.mean()) / (returns.std() + 1e-30) rewards = [torch.zeros(len(x)) for x in all_actions_ixs] diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index 4f73f01c3..cd2b8f674 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -1,11 +1,9 @@ -import os from time import time import ray import torch import torch.nn.functional as F -import trlx.utils.logging as logging from trlx.data.accelerate_base_datatypes import PromptBatch from trlx.data.ppo_types import PPORLElement from trlx.orchestrator import Orchestrator, register_orchestrator @@ -14,8 +12,6 @@ from trlx.utils import Clock from trlx.utils.modeling import RunningMoments, logprobs_from_logits -logger = logging.get_logger(__name__) - @register_orchestrator class PPOOrchestrator(Orchestrator): @@ -57,24 +53,11 @@ def score(self, samples): def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: """ Takes `num_rollouts` prompts from `pipeline`, samples model and computes the - KL againts a reference model. It then appends PPOElements to trainer's `store` + KL against a reference model. It then appends PPOElements to trainer's `store` """ - logger.info("Collecting rollouts") - tbar = logging.tqdm( - total=num_rollouts, - disable=os.environ.get("RANK", 0) != "0", - desc=f"[rollout 0 / {num_rollouts}]", - # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress - # bars (e.g. 
loss progress in trainers) - position=logging.get_verbosity() >= logging.WARNING, - # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels - leave=logging.get_verbosity() < logging.WARNING, - ) - ppo_rl_elements = [] stats = {} clock = Clock() - while len(ppo_rl_elements) < num_rollouts: # Get next batch in prompt dataset and refresh if exhausted try: @@ -92,7 +75,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq str_samples, str_prompts, str_outputs = self.trainer.decode(query_tensors, samples) # Convert trimmed samples back into tensors for another head pass - # This can be defered, instead letting the pass to made over the original samples + # This can be deferred, instead letting the pass to made over the original samples # after unbinding and truncating operations lower are fixed outputs = self.trainer.tokenizer(str_outputs).input_ids outputs = list(map(torch.LongTensor, outputs)) @@ -215,7 +198,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rewards = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs) rewards = [rs[start : ends[ix]] for ix, rs in enumerate(rewards)] - rollout_count = 0 for ix in range(n): if len(rewards[ix]) == 0 or len(all_logprobs[ix]) == 0: continue @@ -231,11 +213,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rewards=rewards[ix], ) ) - rollout_count += 1 + exp_time = clock.tick() - tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]") - tbar.update(min(rollout_count, num_rollouts)) - tbar.close() stats["kl_ctl_value"] = self.trainer.kl_ctl.value stats["time/exp"] = exp_time diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py index 68142bab2..63bf189da 100644 --- a/trlx/trainer/__init__.py +++ b/trlx/trainer/__init__.py @@ -76,7 +76,7 @@ def sample(self, prompts: Iterable[str], length: int, n_samples: int) -> Iterabl :param prompts: List of prompts to tokenize and use as context - :param length: How many new tokens to genrate for each prompt + :param length: How many new tokens to generate for each prompt :type length: int :param n_samples: Default behavior is to take number of prompts as this diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 11c0ce68d..a76ee5cbb 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -13,9 +13,9 @@ from ray.air.checkpoint import Checkpoint from rich.console import Console from rich.table import Table +from tqdm import tqdm from transformers import AutoTokenizer -import trlx.utils.logging as logging from trlx.data.configs import TRLConfig from trlx.trainer import BaseRLTrainer, register_trainer from trlx.utils import ( @@ -24,6 +24,7 @@ get_git_tag, get_optimizer_class, get_scheduler_class, + print_rank_0, significant, ) from trlx.utils.modeling import ( @@ -34,8 +35,6 @@ parse_delta_kwargs, ) -logger = logging.get_logger(__name__) - @register_trainer class AccelerateRLTrainer(BaseRLTrainer): @@ -117,8 +116,6 @@ def setup_model(self): """ Returns a model derived from an instance's TRLConfig """ - logger.info(f"Initializing model: {self.config.model.model_path}") - # Retrieves model equipped for ppo, ilql, etc model = self.get_arch(self.config) if self.config.model.model_arch_type == "seq2seq": @@ -282,7 +279,8 @@ def add_eval_pipeline(self, eval_pipeline): def evaluate(self): # noqa: C901 """Samples model on `eval_prompts`, logs stats with 
`reward_fn` or `metric_fn` if provided""" - logger.info("Evaluating model") + stats = {} + table = [] # Do multiple evaluations over a single list in `gen_kwargs` if present if self.generate_sweep_kwarg is not None: @@ -290,22 +288,7 @@ def evaluate(self): # noqa: C901 else: gen_sweep_values = [None] - desc = [ - f"generation sweep 0/{len(gen_sweep_values)}", - f"eval batch 0/{len(self.eval_dataloader)}", - ] - tbar = logging.tqdm( - total=len(self.eval_dataloader) * len(gen_sweep_values), - desc=f"[{' | '.join(desc)}]", - disable=not self.accelerator.is_main_process, - position=0, - leave=True, - ) - - stats = {} - table = [] - - for i_sweep, gen_sweep_value in enumerate(gen_sweep_values): + for gen_sweep_value in gen_sweep_values: # A dedicated suffix for wandb logging if gen_sweep_value is not None: sweep_suffix = f"@{gen_sweep_arg}={gen_sweep_value}" @@ -316,7 +299,7 @@ def evaluate(self): # noqa: C901 all_prompts = [] prompt_sizes = [] generate_time = time() - for i_prompt, prompts in enumerate(self.eval_dataloader): + for prompts in self.eval_dataloader: if self.generate_sweep_kwarg: samples = self.generate_eval(**prompts, **{gen_sweep_arg: gen_sweep_value}) else: @@ -343,14 +326,6 @@ def evaluate(self): # noqa: C901 torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat(len(prompts.input_ids)) ) - desc = [ - f"generation sweep {i_sweep + 1}/{len(gen_sweep_values)}", - f"eval batch {i_prompt + 1}/{len(self.eval_dataloader)}", - ] - tbar.set_description(f"[{' | '.join(desc)}]") - tbar.update() - tbar.close() - stats["time/generate"] = time() - generate_time samples = self.accelerator.gather(torch.vstack(all_samples)) @@ -365,7 +340,6 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: - logger.info("Computing rewards") rewards = torch.tensor( self.reward_fn( samples=str_samples, @@ -383,7 +357,6 @@ def evaluate(self): # noqa: C901 # additionally log any other metrics if self.metric_fn: - logger.info("Computing metrics") metric_time = time() metrics = self.metric_fn( samples=str_samples, @@ -412,7 +385,6 @@ def evaluate(self): # noqa: C901 table.append(list(zip(*columns_data))) # Log and display evaluation metrics - logger.info("Summarizing evaluation") if self.accelerator.is_main_process: rows = sum(list(map(list, zip(*table))), []) @@ -423,9 +395,9 @@ def evaluate(self): # noqa: C901 table_title += f" {k}: {significant(x)}" rich_table = Table(*columns, title=table_title, show_lines=True) + for ix in range(max(min(3, len(rows)), len(gen_sweep_values))): rich_table.add_row(*[str(significant(x)) for x in rows[ix]]) - Console().print(rich_table) if not ray.is_initialized(): if self.config.train.tracker == "wandb": @@ -433,6 +405,8 @@ def evaluate(self): # noqa: C901 stats["samples"] = wandb.Table(columns, rows) + Console().print(rich_table) + self.nth_evaluation += 1 return stats @@ -440,13 +414,11 @@ def learn(self): # noqa: C901 """ Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader` """ - logger.info("Starting training") - self.generate_sweep_kwarg = None for k, v in self.config.method.gen_kwargs.items(): if isinstance(v, list): if self.generate_sweep_kwarg is not None: - logger.info("Only a single sweep is allowed, {k} is going to be set to {v[0]}") + print_rank_0("Only a single sweep is allowed, {k} is going to be set to {v[0]}") self.generate_kwargs[k] = v[0] else: self.generate_sweep_kwarg = (k, v) @@ -468,12 +440,10 @@ def learn(self): # noqa: C901 
results = self.evaluate() self.accelerator.log(results, step=self.iter_count) - tbar = logging.tqdm( + tbar = tqdm( initial=self.iter_count, total=self.total_steps, disable=not self.accelerator.is_local_main_process, - position=0, - leave=True, ) best_reward = -float("inf") @@ -521,7 +491,7 @@ def learn(self): # noqa: C901 torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX) if do_save: best_path = f"{self.config.train.checkpoint_dir}/best_checkpoint" - logger.info(f"Saving the best state so far into {best_path}") + print_rank_0(f"saving the best state so far into {best_path}") self.save(best_path) # Report the metrics to Ray Tune. @@ -535,8 +505,8 @@ def learn(self): # noqa: C901 if not ray.is_initialized(): self.accelerator.log(stats, step=self.iter_count) - desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) - tbar.set_description(f"[{desc}]") + desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) + tbar.set_description(desc) tbar.update() if self.iter_count >= self.total_steps: @@ -546,7 +516,6 @@ def learn(self): # noqa: C901 self.post_backward_callback() self.post_epoch_callback() - tbar.close() @abstractmethod def get_arch(self, config: TRLConfig): diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py index 35d0031b6..3c059b868 100644 --- a/trlx/trainer/accelerate_ilql_trainer.py +++ b/trlx/trainer/accelerate_ilql_trainer.py @@ -87,7 +87,7 @@ def save_pretrained(self, directory: Optional[str] = None): of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`). """ # TODO: Support saving with `transformers.PreTrainedModel.save_pretrained`. - # This is currently not supported becasue `nn.ilql_models.CausalLMWithValueHeads` + # This is currently not supported because `nn.ilql_models.CausalLMWithValueHeads` # requires a custom `generate` method using its (value/q) heads to steer # sampling - something that is not possible with the default # `transformers.PreTrainedModel.generate`. diff --git a/trlx/trainer/nemo/gpt.py b/trlx/trainer/nemo/gpt.py index 89eb2554b..ca1abc51f 100644 --- a/trlx/trainer/nemo/gpt.py +++ b/trlx/trainer/nemo/gpt.py @@ -666,7 +666,7 @@ def fwd_output_and_loss_func(batch: List[torch.Tensor], model, checkpoint_activa ) else: # In-between stages are given data via the pipeline engine - # Still need to specify thes arguments to avoid errors + # Still need to specify these arguments to avoid errors model_output = model(input_ids=None, position_ids=None, attention_mask=None) def gather_ntc(t: torch.Tensor): diff --git a/trlx/trlx.py b/trlx/trlx.py index a6b269f81..fb1958544 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -34,14 +34,14 @@ def train( prompts (List[str]): Prompts to sample off from during online training eval_prompts (List[str]): Prompts to periodically validate training on metric_fn (Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]): - Function to compute statistics on batches of gnerated samples. Its arguments are the same + Function to compute statistics on batches of generated samples. 
Its arguments are the same as in `reward_fn` (`samples`, `prompts`, `outputs`) but the return is dictionary with keys as metric's name and values and lists of numeric values per each sample in batch config (Optional[TRLConfig]): TRL configuration object to override default settings logit_mask (Optional[List]): Bigram masking matrix stop_sequences (Optional[List[str]]): String sequences to trim generations (either for experience or evaluation) up to its - encounter in them. Generatations will not contain them and also will be right-stripped + encounter in them. Generations will not contain them and also will be right-stripped """ if reward_fn is not None: diff --git a/trlx/utils/logging.py b/trlx/utils/logging.py deleted file mode 100644 index 79badb4a3..000000000 --- a/trlx/utils/logging.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2023 Optuna, Hugging Face, CarperAI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Logging utilities.""" - -import logging -import os -import sys -import threading -from logging import CRITICAL # NOQA -from logging import DEBUG # NOQA -from logging import ERROR # NOQA -from logging import FATAL # NOQA -from logging import INFO # NOQA -from logging import NOTSET # NOQA -from logging import WARN # NOQA -from logging import WARNING # NOQA -from typing import Optional - -import torch -from tqdm import auto as tqdm_lib - -_lock = threading.Lock() -_default_handler: Optional[logging.Handler] = None - -log_levels = { - "debug": logging.DEBUG, - "info": logging.INFO, - "warning": logging.WARNING, - "error": logging.ERROR, - "critical": logging.CRITICAL, -} - -_default_log_level = logging.INFO - - -def _get_default_logging_level(): - """ - If `TRLX_VERBOSITY` env var is set to one of the valid choices, return that as the new default level. If it is - not - fall back to `_default_log_level` - """ - env_level_str = os.getenv("TRLX_VERBOSITY", None) - if env_level_str: - if env_level_str.lower() in log_levels: - return log_levels[env_level_str.lower()] - else: - logging.getLogger().warning( - f"Unknown option TRLX_VERBOSITY={env_level_str}, " f"has to be one of: { ', '.join(log_levels.keys()) }" - ) - return _default_log_level - - -def _get_library_name() -> str: - return __name__.split(".")[0] - - -def _get_library_root_logger() -> logging.Logger: - return logging.getLogger(_get_library_name()) - - -def _configure_library_root_logger() -> None: - global _default_handler - - with _lock: - if _default_handler: - # This library has already configured the library root logger. - return - _default_handler = logging.StreamHandler() # Set sys.stderr as stream. - _default_handler.flush = sys.stderr.flush - - # Apply our default configuration to the library root logger. 
- library_root_logger = _get_library_root_logger() - library_root_logger.addHandler(_default_handler) - library_root_logger.setLevel(_get_default_logging_level()) - library_root_logger.propagate = False - - -def _reset_library_root_logger() -> None: - global _default_handler - - with _lock: - if not _default_handler: - return - - library_root_logger = _get_library_root_logger() - library_root_logger.removeHandler(_default_handler) - library_root_logger.setLevel(logging.NOTSET) - _default_handler = None - - -def get_log_levels_dict(): - return log_levels - - -class MultiProcessAdapter(logging.LoggerAdapter): - """A logger adapter for handling multi-process logging""" - - def log(self, level, msg, *args, **kwargs): - """ - Consumes an additional kwarg called `ranks` to determine which processes should log. - NOTE: To specify all processes, pass in an empty list `ranks=[]` - - Default: ["0"], i.e. only the main process logs - """ - # By default, silence all non-main processes - ranks = kwargs.pop("ranks", ["0"]) - should_log = os.environ.get("RANK", "0") in ranks or len(ranks) == 0 - if self.isEnabledFor(level) and should_log: - msg, kwargs = self.process(msg, kwargs) - self.logger._log(level, msg, args, **kwargs) - - def process(self, msg, kwargs): - this_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - return f"[RANK {this_rank}] {msg}", kwargs - - -def get_logger(name: Optional[str] = None) -> MultiProcessAdapter: - """ - Returns a `logging.Logger` for `name` that can handle multiple processes - - Args: - name: Name of the logger - - Usage: - >> logger = get_logger(__name__) - >> logger.debug("Check the...", ranks=["0", "1"]) # Only main and rank 1 log - """ - if name is None: - name = _get_library_name() - _configure_library_root_logger() - logger = logging.getLogger(name) - return MultiProcessAdapter(logger, {}) - - -def get_verbosity() -> int: - """ - Return the current level for trlx's root logger as an int. - Returns: - `int`: The logging level. - - trlx has following logging levels: - - 50: `trlx.logging.CRITICAL` or `trlx.logging.FATAL` - - 40: `trlx.logging.ERROR` - - 30: `trlx.logging.WARNING` or `trlx.logging.WARN` - - 20: `trlx.logging.INFO` - - 10: `trlx.logging.DEBUG` - - """ - - _configure_library_root_logger() - return _get_library_root_logger().getEffectiveLevel() - - -def set_verbosity(verbosity: int) -> None: - """ - Set the verbosity level for trlX's root logger. 
- Args: - verbosity (`int`): - Logging level, e.g., one of: - - `trlx.logging.CRITICAL` or `trlx.logging.FATAL` - - `trlx.logging.ERROR` - - `trlx.logging.WARNING` or `trlx.logging.WARN` - - `trlx.logging.INFO` - - `trlx.logging.DEBUG` - """ - - _configure_library_root_logger() - _get_library_root_logger().setLevel(verbosity) - - -def disable_default_handler() -> None: - """Disable the default handler of trlx's root logger.""" - - _configure_library_root_logger() - - assert _default_handler is not None - _get_library_root_logger().removeHandler(_default_handler) - - -def enable_default_handler() -> None: - """Enable the default handler of trlx's root logger.""" - - _configure_library_root_logger() - - assert _default_handler is not None - _get_library_root_logger().addHandler(_default_handler) - - -def add_handler(handler: logging.Handler) -> None: - """Adds a handler to trlx's root logger.""" - - _configure_library_root_logger() - - assert handler is not None - _get_library_root_logger().addHandler(handler) - - -def remove_handler(handler: logging.Handler) -> None: - """Removes given handler from the trlx's root logger.""" - - _configure_library_root_logger() - - assert handler is not None and handler not in _get_library_root_logger().handlers - _get_library_root_logger().removeHandler(handler) - - -def disable_propagation() -> None: - """ - Disable propagation of the library log outputs. Note that log propagation is disabled by default. - """ - - _configure_library_root_logger() - _get_library_root_logger().propagate = False - - -def enable_propagation() -> None: - """ - Enable propagation of the library log outputs. Please disable the trlx's default handler to prevent - double logging if the root logger has been configured. - """ - - _configure_library_root_logger() - _get_library_root_logger().propagate = True - - -def enable_explicit_format() -> None: - """ - Enable explicit formatting for every trlx's logger. The explicit formatter is as follows: - ``` - [ASCTIME] [LEVELNAME] [FILENAME:LINE NUMBER:FUNCNAME] MESSAGE - ``` - All handlers currently bound to the root logger are affected by this method. - """ - handlers = _get_library_root_logger().handlers - - for handler in handlers: - formatter = logging.Formatter( - "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s" - ) - handler.setFormatter(formatter) - - -def reset_format() -> None: - """ - Resets the formatting for trlx's loggers. - All handlers currently bound to the root logger are affected by this method. 
- """ - handlers = _get_library_root_logger().handlers - - for handler in handlers: - handler.setFormatter(None) - - -def warning_advice(self, *args, **kwargs): - """ - This method is identical to `logger.warning()`, but if env var TRLX_NO_ADVISORY_WARNINGS=1 is set, this - warning will not be printed - """ - no_advisory_warnings = os.getenv("TRLX_NO_ADVISORY_WARNINGS", False) - if no_advisory_warnings: - return - self.warning(*args, **kwargs) - - -logging.Logger.warning_advice = warning_advice - - -class EmptyTqdm: - """Dummy tqdm which doesn't do anything.""" - - def __init__(self, *args, **kwargs): # pylint: disable=unused-argument - self._iterator = args[0] if args else None - - def __iter__(self): - return iter(self._iterator) - - def __getattr__(self, _): - """Return empty function.""" - - def empty_fn(*args, **kwargs): # pylint: disable=unused-argument - return - - return empty_fn - - def __enter__(self): - return self - - def __exit__(self, type_, value, traceback): - return - - -_tqdm_active = True - - -class _tqdm_cls: - def __call__(self, *args, **kwargs): - if _tqdm_active: - return tqdm_lib.tqdm(*args, **kwargs) - else: - return EmptyTqdm(*args, **kwargs) - - def set_lock(self, *args, **kwargs): - self._lock = None - if _tqdm_active: - return tqdm_lib.tqdm.set_lock(*args, **kwargs) - - def get_lock(self): - if _tqdm_active: - return tqdm_lib.tqdm.get_lock() - - -tqdm = _tqdm_cls() - - -def is_progress_bar_enabled() -> bool: - """Return a boolean indicating whether tqdm progress bars are enabled.""" - global _tqdm_active - return bool(_tqdm_active) - - -def enable_progress_bar(): - """Enable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = True - - -def disable_progress_bar(): - """Disable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = False