Commit

Add SMTWTP env

henry-yeh committed Oct 7, 2023
1 parent 9f2a579 commit 3c0ac1e
Showing 5 changed files with 285 additions and 0 deletions.
2 changes: 2 additions & 0 deletions rl4co/envs/__init__.py
@@ -11,6 +11,7 @@
from rl4co.envs.pctsp import PCTSPEnv
from rl4co.envs.pdp import PDPEnv
from rl4co.envs.sdvrp import SDVRPEnv
from rl4co.envs.smtwtp import SMTWTPEnv
from rl4co.envs.spctsp import SPCTSPEnv
from rl4co.envs.tsp import TSPEnv

@@ -27,6 +28,7 @@
"sdvrp": SDVRPEnv,
"spctsp": SPCTSPEnv,
"tsp": TSPEnv,
"smtwtp": SMTWTPEnv,
}


239 changes: 239 additions & 0 deletions rl4co/envs/smtwtp.py
@@ -0,0 +1,239 @@
from typing import Optional

import torch

from tensordict.tensordict import TensorDict
from torchrl.data import (
BoundedTensorSpec,
CompositeSpec,
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
)

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class SMTWTPEnv(RL4COEnvBase):
"""
Single Machine Total Weighted Tardiness Problem environment as described in DeepACO (https://arxiv.org/pdf/2309.14032.pdf)
SMTWTP is a scheduling problem in which a set of jobs must be processed on a single machine.
Each job i has a processing time, a weight, and a due date. The objective is to minimize the sum of the weighted tardiness of all jobs,
where the weighted tardiness of a job is defined as the product of its weight and the duration by which its completion time exceeds its due date.
At each step, the agent chooses a job to process. The reward is -infinity unless the agent has processed all the jobs.
In that case, the reward is the negative objective value of the processing order: maximizing the reward is equivalent to minimizing the objective.
Args:
num_job: number of jobs
min_time_span: lower bound of jobs' due time. By default, jobs' due time is uniformly sampled from (min_time_span, max_time_span)
max_time_span: upper bound of jobs' due time. By default, it will be set to num_job / 2
min_job_weight: lower bound of jobs' weights. By default, jobs' weights are uniformly sampled from (min_job_weight, max_job_weight)
max_job_weight: upper bound of jobs' weights
min_process_time: lower bound of jobs' process time. By default, jobs' process time is uniformly sampled from (min_process_time, max_process_time)
max_process_time: upper bound of jobs' process time
td_params: parameters of the environment
seed: seed for the environment
device: device to use. Generally, no need to set as tensors are updated on the fly
"""

name = "smtwtp"

def __init__(
self,
num_job: int = 10,
min_time_span: float = 0,
max_time_span: float = None,  # defaults to num_job / 2; DeepACO uses num_job, which makes the instances too easy
min_job_weight: float = 0,
max_job_weight: float = 1,
min_process_time: float = 0,
max_process_time: float = 1,
td_params: TensorDict = None,
**kwargs,
):
super().__init__(**kwargs)
self.num_job = num_job
self.min_time_span = min_time_span
self.max_time_span = num_job / 2 if max_time_span is None else max_time_span
self.min_job_weight = min_job_weight
self.max_job_weight = max_job_weight
self.min_process_time = min_process_time
self.max_process_time = max_process_time
self._make_spec(td_params)

@staticmethod
def _step(td: TensorDict) -> TensorDict:
current_job = td["action"]

# Mark the selected job as visited (set its entry in the action mask to 0)
available = td["action_mask"].scatter(

-1, current_job.unsqueeze(-1).expand_as(td["action_mask"]), 0
)

# Increase used time
selected_process_time = td["job_process_time"][

torch.arange(current_job.size(0)), current_job
]
current_time = td["current_time"] + selected_process_time.unsqueeze(-1)

# We are done when there are no unvisited locations
done = torch.count_nonzero(available, dim=-1) <= 0

# The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
reward = torch.ones_like(done) * float("-inf")

# The output must be written in a ``"next"`` entry
return TensorDict(

{
"next": {
"job_due_time": td["job_due_time"],
"job_weight": td["job_weight"],
"job_process_time": td["job_process_time"],
"current_job": current_job,
"current_time": current_time,
"action_mask": available,
"reward": reward,
"done": done,
}
},
td.shape,
)

def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
# Initialization
if batch_size is None:
batch_size = self.batch_size if td is None else td["job_due_time"].shape[:-1]
batch_size = [batch_size] if isinstance(batch_size, int) else batch_size

self.device = device = (

td["job_due_time"].device if td is not None else self.device
)

td = self.generate_data(batch_size) if td is None else td

init_job_due_time = td["job_due_time"]
init_job_process_time = td["job_process_time"]
init_job_weight = td["job_weight"]

# Other variables
current_job = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device)
current_time = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device)
available = torch.ones(

(*batch_size, self.num_job + 1), dtype=torch.bool, device=device
)
available[:, 0] = 0 # mask the starting dummy node

return TensorDict(

{
"job_due_time": init_job_due_time,
"job_weight": init_job_weight,
"job_process_time": init_job_process_time,
"current_job": current_job,
"current_time": current_time,
"action_mask": available,
},
batch_size=batch_size,
)

def _make_spec(self, td_params: TensorDict = None):
self.observation_spec = CompositeSpec(

job_due_time=BoundedTensorSpec(
minimum=self.min_time_span,
maximum=self.max_time_span,
shape=(self.num_job + 1,),
dtype=torch.float32,
),
job_weight=BoundedTensorSpec(
minimum=self.min_job_weight,
maximum=self.max_job_weight,
shape=(self.num_job + 1,),
dtype=torch.float32,
),
job_process_time=BoundedTensorSpec(
minimum=self.min_process_time,
maximum=self.max_process_time,
shape=(self.num_job + 1,),
dtype=torch.float32,
),
current_node=UnboundedDiscreteTensorSpec(
shape=(1,),
dtype=torch.int64,
),
action_mask=UnboundedDiscreteTensorSpec(
shape=(self.num_job + 1,),
dtype=torch.bool,
),
current_time=UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
),
shape=(),
)
self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(

shape=(1,),
dtype=torch.int64,
minimum=0,
maximum=self.num_job + 1,
)
self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)

def get_reward(self, td, actions) -> TensorDict:
job_due_time = td["job_due_time"]
job_weight = td["job_weight"]
job_process_time = td["job_process_time"]

batch_idx = torch.arange(

job_process_time.shape[0], device=job_process_time.device
).unsqueeze(1)

ordered_process_time = job_process_time[batch_idx, actions]
ordered_due_time = job_due_time[batch_idx, actions]
ordered_job_weight = job_weight[batch_idx, actions]
presum_process_time = torch.cumsum(

ordered_process_time, dim=1
) # ending time of each job
job_tardiness = presum_process_time - ordered_due_time
job_tardiness[job_tardiness < 0] = 0
job_weighted_tardiness = ordered_job_weight * job_tardiness

return -job_weighted_tardiness.sum(-1)

def generate_data(self, batch_size) -> TensorDict:
batch_size = [batch_size] if isinstance(batch_size, int) else batch_size

# Sampling according to Ye et al. (2023)
job_due_time = (

torch.FloatTensor(*batch_size, self.num_job + 1)
.uniform_(self.min_time_span, self.max_time_span)
.to(self.device)
)
job_weight = (

torch.FloatTensor(*batch_size, self.num_job + 1)
.uniform_(self.min_job_weight, self.max_job_weight)
.to(self.device)
)
job_process_time = (

torch.FloatTensor(*batch_size, self.num_job + 1)
.uniform_(self.min_process_time, self.max_process_time)
.to(self.device)
)

# Rollouts begin at dummy node 0, whose features are set to 0
job_due_time[:, 0] = 0
job_weight[:, 0] = 0
job_process_time[:, 0] = 0

return TensorDict(

{
"job_due_time": job_due_time,
"job_weight": job_weight,
"job_process_time": job_process_time,
},
batch_size=batch_size,
)

@staticmethod
def render(td, actions=None, ax=None):
raise NotImplementedError("TODO: render is not implemented yet")
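For reference, a minimal rollout sketch against the new environment (not part of this commit). It assumes the standard RL4COEnvBase reset/step interface wraps the _reset/_step defined above, with the TorchRL-style transition under "next"; the greedy job selection is purely illustrative.

import torch
from rl4co.envs import SMTWTPEnv

env = SMTWTPEnv(num_job=10)
td = env.reset(batch_size=[4])  # assumed RL4COEnvBase entry point to _reset above

actions = []
for _ in range(env.num_job):  # exactly num_job decisions; node 0 is the masked dummy start
    # Illustrative policy: pick one still-available job via argmax over the boolean mask
    action = td["action_mask"].float().argmax(dim=-1)
    actions.append(action)
    td["action"] = action
    td = env.step(td)["next"]  # assumed TorchRL convention matching _step above

actions = torch.stack(actions, dim=1)  # (batch, num_job) processing order
reward = env.get_reward(td, actions)   # negative total weighted tardiness
print(reward.shape)                    # torch.Size([4])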

20 changes: 20 additions & 0 deletions rl4co/models/nn/env_embeddings/context.py
@@ -25,6 +25,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
"mdpp": DPPContext,
"pdp": PDPContext,
"mtsp": MTSPContext,
"smtwtp": SMTWTPContext,
}

if env_name not in embedding_registry:
@@ -213,3 +214,22 @@ def _distance_from_depot(self, td):
# Euclidean distance from the depot (loc[..., 0, :])
cur_loc = gather_by_index(td["locs"], td["current_node"])
return torch.norm(cur_loc - td["locs"][..., 0, :], dim=-1)


class SMTWTPContext(EnvContext):
"""Context embedding for the Single Machine Total Weighted Tardiness Problem (SMTWTP).
Project the following to the embedding space:
- current node embedding
- current time
"""

def __init__(self, embedding_dim):
super(SMTWTPContext, self).__init__(embedding_dim, embedding_dim + 1)

def _cur_node_embedding(self, embeddings, td):
cur_node_embedding = gather_by_index(embeddings, td["current_job"])
return cur_node_embedding

def _state_embedding(self, embeddings, td):
state_embedding = td["current_time"]
return state_embedding
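For intuition, the standalone sketch below mimics what this context presumably produces for the decoder: the embedding of the current job concatenated with the scalar current time, which is why the step context dimension passed in __init__ above is embedding_dim + 1. This is an assumed view of the EnvContext base class (dimensions are illustrative), not its actual code.

import torch
from torch import nn

batch, num_nodes, embedding_dim = 4, 11, 128
embeddings = torch.randn(batch, num_nodes, embedding_dim)  # per-node encoder output
current_job = torch.randint(0, num_nodes, (batch, 1))
current_time = torch.rand(batch, 1)

# Gather the embedding of the currently selected job, as _cur_node_embedding does
cur_node_emb = embeddings.gather(
    1, current_job.unsqueeze(-1).expand(-1, -1, embedding_dim)
).squeeze(1)                                               # (batch, embedding_dim)

context = torch.cat([cur_node_emb, current_time], dim=-1)  # (batch, embedding_dim + 1)
project = nn.Linear(embedding_dim + 1, embedding_dim)      # assumed projection inside EnvContext
print(project(context).shape)                              # torch.Size([4, 128])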

1 change: 1 addition & 0 deletions rl4co/models/nn/env_embeddings/dynamic.py
@@ -26,6 +26,7 @@ def env_dynamic_embedding(env_name: str, config: dict) -> nn.Module:
"mdpp": StaticEmbedding,
"pdp": StaticEmbedding,
"mtsp": StaticEmbedding,
"smtwtp": StaticEmbedding,
}

if env_name not in embedding_registry:
23 changes: 23 additions & 0 deletions rl4co/models/nn/env_embeddings/init.py
@@ -23,6 +23,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
"mdpp": MDPPInitEmbedding,
"pdp": PDPInitEmbedding,
"mtsp": MTSPInitEmbedding,
"smtwtp": SMTWTPInitEmbedding,
}

if env_name not in embedding_registry:
@@ -246,3 +247,25 @@ def forward(self, td):
depot_embedding = self.init_embed_depot(td["locs"][..., 0:1, :])
node_embedding = self.init_embed(td["locs"][..., 1:, :])
return torch.cat([depot_embedding, node_embedding], -2)


class SMTWTPInitEmbedding(nn.Module):
"""Initial embedding for the Single Machine Total Weighted Tardiness Problem (SMTWTP).
Embed the following node features to the embedding space:
- job_due_time: due time of the jobs
- job_weight: weights of the jobs
- job_process_time: the processing time of jobs
"""

def __init__(self, embedding_dim, linear_bias=True):
super(SMTWTPInitEmbedding, self).__init__()
node_dim = 3 # job_due_time, job_weight, job_process_time
self.init_embed = nn.Linear(node_dim, embedding_dim, linear_bias)

def forward(self, td):
job_due_time = td["job_due_time"]
job_weight = td["job_weight"]
job_process_time = td["job_process_time"]
feat = torch.stack((job_due_time, job_weight, job_process_time), dim=-1)
out = self.init_embed(feat)
return out
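A quick standalone shape check of the init embedding above (dimensions are illustrative; a plain dict stands in for the TensorDict):

import torch
from torch import nn

batch, num_job, embedding_dim = 4, 10, 128
td = {
    "job_due_time": torch.rand(batch, num_job + 1),
    "job_weight": torch.rand(batch, num_job + 1),
    "job_process_time": torch.rand(batch, num_job + 1),
}
init_embed = nn.Linear(3, embedding_dim)  # node_dim = 3, as in SMTWTPInitEmbedding
feat = torch.stack(
    (td["job_due_time"], td["job_weight"], td["job_process_time"]), dim=-1
)                                  # (batch, num_job + 1, 3)
print(init_embed(feat).shape)      # torch.Size([4, 11, 128])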
