Skip to content

Commit

Permalink
[Feat] automatic testing
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Jul 19, 2024
1 parent af7c37b commit db251b0
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 0 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Tests
on: [push, pull_request]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: true
      max-parallel: 15
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version: ['3.11']
    defaults:
      run:
        shell: bash
    steps:
      - name: Check out repository
        uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        # BUGFIX: this `id` is required — the cache key below reads
        # `steps.setup-python.outputs.python-version`; without the id the
        # expression resolved to an empty string and the cache key was broken.
        id: setup-python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Load cached venv
        id: cached-pip-wheels
        uses: actions/cache@v3
        with:
          path: ~/.cache
          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[all]"
      - name: Run pytest
        # Explicit glob: the test files (env.py, policy.py, training.py) do not
        # match pytest's default `test_*.py` discovery pattern, so plain
        # `pytest` would collect nothing.
        run: pytest tests/*.py
46 changes: 46 additions & 0 deletions tests/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest

from routefinder.envs.mtvrp import MTVRPEnv, MTVRPGenerator
from routefinder.models import RouteFinderPolicy
from routefinder.utils import greedy_policy, rollout


@pytest.mark.parametrize(
    "variant_preset",
    [
        "all",
        "single_feat",
        "single_feat_otw",
        "cvrp",
        "ovrp",
        "vrpb",
        "vrpl",
        "vrptw",
        "ovrptw",
        "ovrpb",
        "ovrpl",
        "vrpbl",
        "vrpbtw",
        "vrpltw",
        "ovrpbl",
        "ovrpbtw",
        "ovrpltw",
        "vrpbltw",
        "ovrpbltw",
    ],
)
def test_env(variant_preset):
    """Smoke-test MTVRPEnv for every variant preset.

    For each preset we generate a tiny batch, run a greedy rollout through
    the environment API, and run a forward pass of an untrained
    RouteFinderPolicy, checking that both produce one reward per instance.
    """
    batch_size = 3
    # Sample all variants in the same batch (Mixed-Batch Training)
    generator = MTVRPGenerator(num_loc=10, variant_preset=variant_preset)
    env = MTVRPEnv(generator, check_solution=True)
    td_test = env.reset(env.generator(batch_size))

    # Greedy rollout driven directly through the environment.
    greedy_actions = rollout(env, td_test.clone(), greedy_policy)
    greedy_rewards = env.get_reward(td_test, greedy_actions)
    assert greedy_rewards.shape == (batch_size,)

    # Forward pass of an untrained policy in test mode.
    policy = RouteFinderPolicy()
    out = policy(
        td_test.clone(), env, phase="test", decode_type="greedy", return_actions=True
    )
    assert out["reward"].shape == (batch_size,)
41 changes: 41 additions & 0 deletions tests/policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from routefinder.envs.mtvrp import MTVRPEnv, MTVRPGenerator
from routefinder.models import RouteFinderPolicy


@pytest.mark.parametrize(
    "variant_preset",
    [
        "all",
        "single_feat",
        "single_feat_otw",
        "cvrp",
        "ovrp",
        "vrpb",
        "vrpl",
        "vrptw",
        "ovrptw",
        "ovrpb",
        "ovrpl",
        "vrpbl",
        "vrpbtw",
        "vrpltw",
        "ovrpbl",
        "ovrpbtw",
        "ovrpltw",
        "vrpbltw",
        "ovrpbltw",
    ],
)
def test_policy(variant_preset):
    """Smoke-test RouteFinderPolicy for every variant preset.

    Runs a greedy test-phase forward pass of an untrained policy on a tiny
    batch and checks that it yields one reward per instance.
    """
    batch_size = 3
    # Sample all variants in the same batch (Mixed-Batch Training)
    generator = MTVRPGenerator(num_loc=10, variant_preset=variant_preset)
    env = MTVRPEnv(generator, check_solution=True)
    td_test = env.reset(env.generator(batch_size))

    policy = RouteFinderPolicy()
    out = policy(
        td_test.clone(), env, phase="test", decode_type="greedy", return_actions=True
    )
    assert out["reward"].shape == (batch_size,)
26 changes: 26 additions & 0 deletions tests/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from rl4co.utils.trainer import RL4COTrainer

from routefinder.envs.mtvrp import MTVRPEnv
from routefinder.models import RouteFinderBase, RouteFinderPolicy


def test_training():
    """End-to-end training smoke test.

    Fits a RouteFinderBase model on the "all" variant preset for a single
    epoch with tiny datasets (3 instances each), then runs the test loop.
    Only checks that the full train/test pipeline executes without error.
    """
    env = MTVRPEnv(generator_params={"num_loc": 10, "variant_preset": "all"})
    model = RouteFinderBase(
        env,
        RouteFinderPolicy(),
        batch_size=3,
        train_data_size=3,
        val_data_size=3,
        test_data_size=3,
        optimizer_kwargs={"lr": 3e-4, "weight_decay": 1e-6},
    )

    # One-epoch trainer on whatever accelerator is available.
    trainer = RL4COTrainer(
        max_epochs=1,
        gradient_clip_val=None,
        devices=1,
        accelerator="auto",
    )
    trainer.fit(model)
    trainer.test(model)

0 comments on commit db251b0

Please sign in to comment.