Skip to content

Commit

Permalink
Merge pull request #1 from DavideTr8/feture/unbatched_jssp_l2d
Browse files Browse the repository at this point in the history
Feture/unbatched jssp l2d
  • Loading branch information
DavideTr8 authored Jan 28, 2024
2 parents 1a2da37 + f0b26ac commit c385153
Show file tree
Hide file tree
Showing 6 changed files with 862 additions and 1 deletion.
135 changes: 135 additions & 0 deletions notebooks/unbatched_jssp.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Tests on the JSSP"
],
"metadata": {
"collapsed": false
},
"id": "bf32bb4a17ac110a"
},
{
"cell_type": "markdown",
"source": [
"## Imports"
],
"metadata": {
"collapsed": false
},
"id": "c09ce06456e8b120"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"import torch\n",
"from tensordict import TensorDict\n",
"from rl4co.envs.scheduling.jssp import JSSPEnv"
],
"metadata": {
"collapsed": false
},
"id": "64a468fadfb58d84",
"execution_count": null
},
{
"cell_type": "markdown",
"source": [
"Test the JSSP on the OR-tools instance: https://developers.google.com/optimization/scheduling/job_shop\n"
],
"metadata": {
"collapsed": false
},
"id": "18c1cd1b1d512f88"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"n_j = 3  # number of jobs\n",
"n_m = 3  # number of machines\n",
"torch.manual_seed(8)\n",
"env = JSSPEnv(n_j, n_m)\n",
"# Instance data from the OR-tools job-shop example: per-job operation\n",
"# durations and the machine each operation runs on.\n",
"durations = torch.tensor([[3, 2, 2], [2, 1, 4], [4, 3, 0]], dtype=torch.float32)\n",
"machines = torch.tensor([[1, 2, 3], [1, 3, 2], [2, 3, 1]], dtype=torch.int32)\n",
"data = TensorDict(\n",
"    {\"durations\": durations.unsqueeze(0), \"machines\": machines.unsqueeze(0)},\n",
"    batch_size=1,\n",
")\n",
"# Optimal operation ordering (sequence of job indices) reported by OR-tools\n",
"# for this instance.\n",
"ortools_sol = [0, 1, 2, 0, 1, 1, 0, 2, 2]\n",
"env.reset(data)\n",
"total_reward = 0\n",
"for action in ortools_sol:\n",
"    data[\"action\"] = torch.tensor([action], dtype=torch.long)\n",
"    td = env._step(data)\n",
"    total_reward += td[\"reward\"].item()\n",
"\n",
"# env.render()\n",
"# The optimal makespan for this instance is 11 (see the OR-tools example page).\n",
"make_span = env.initial_quality - total_reward\n",
"assert make_span.item() == 11"
],
"metadata": {
"collapsed": false
},
"id": "5076924b5b09bc6a",
"execution_count": null
},
{
"cell_type": "markdown",
"source": [
"## Test with parallel envs"
],
"metadata": {
"collapsed": false
},
"id": "d94bd6bd3d57452a"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"from torchrl.envs.vec_envs import ParallelEnv\n",
"envs = ParallelEnv(2, lambda: JSSPEnv(3, 3))\n",
"td = envs.reset()\n",
"td[\"action\"] = torch.tensor([[0], [1]])\n",
"td = envs.step(td)\n",
"\n",
"assert torch.allclose(td[\"next\"][\"feasible_actions\"], torch.tensor([[[1, 3, 6]], [[0, 4, 6]]]))\n",
"\n",
"td[\"action\"] = torch.tensor([[0], [1]])\n",
"td = envs.step(td)\n",
"assert torch.allclose(td[\"next\"][\"feasible_actions\"], torch.tensor([[[2, 3, 6]], [[0, 5, 6]]]))"
],
"metadata": {
"collapsed": false
},
"id": "ccfbbadbfa49d031",
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 2 additions & 1 deletion rl4co/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)

# Scheduling
from rl4co.envs.scheduling import FFSPEnv, SMTWTPEnv
from rl4co.envs.scheduling import FFSPEnv, JSSPEnv, SMTWTPEnv

# Register environments
ENV_REGISTRY = {
Expand All @@ -35,6 +35,7 @@
"spctsp": SPCTSPEnv,
"tsp": TSPEnv,
"smtwtp": SMTWTPEnv,
"jssp": JSSPEnv,
}


Expand Down
1 change: 1 addition & 0 deletions rl4co/envs/scheduling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from rl4co.envs.scheduling.ffsp import FFSPEnv
from rl4co.envs.scheduling.jssp import JSSPEnv
from rl4co.envs.scheduling.smtwtp import SMTWTPEnv
Loading

0 comments on commit c385153

Please sign in to comment.