From 3c0ac1eb3d118e46d05b250ef68cf488b2d70467 Mon Sep 17 00:00:00 2001
From: henry-yeh
Date: Sat, 7 Oct 2023 15:15:58 +0800
Subject: [PATCH] Add SMTWTP env

---
 rl4co/envs/__init__.py                    |   2 +
 rl4co/envs/smtwtp.py                      | 239 ++++++++++++++++++++++
 rl4co/models/nn/env_embeddings/context.py |  20 ++
 rl4co/models/nn/env_embeddings/dynamic.py |   1 +
 rl4co/models/nn/env_embeddings/init.py    |  23 +++
 5 files changed, 285 insertions(+)
 create mode 100644 rl4co/envs/smtwtp.py

diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py
index 1926bd49..26eeecb0 100644
--- a/rl4co/envs/__init__.py
+++ b/rl4co/envs/__init__.py
@@ -11,6 +11,7 @@
 from rl4co.envs.pctsp import PCTSPEnv
 from rl4co.envs.pdp import PDPEnv
 from rl4co.envs.sdvrp import SDVRPEnv
+from rl4co.envs.smtwtp import SMTWTPEnv
 from rl4co.envs.spctsp import SPCTSPEnv
 from rl4co.envs.tsp import TSPEnv
 
@@ -27,6 +28,7 @@
     "sdvrp": SDVRPEnv,
     "spctsp": SPCTSPEnv,
     "tsp": TSPEnv,
+    "smtwtp": SMTWTPEnv,
 }
diff --git a/rl4co/envs/smtwtp.py b/rl4co/envs/smtwtp.py
new file mode 100644
index 00000000..41f5e5af
--- /dev/null
+++ b/rl4co/envs/smtwtp.py
@@ -0,0 +1,239 @@
+from typing import Optional
+
+import torch
+
+from tensordict.tensordict import TensorDict
+from torchrl.data import (
+    BoundedTensorSpec,
+    CompositeSpec,
+    UnboundedContinuousTensorSpec,
+    UnboundedDiscreteTensorSpec,
+)
+
+from rl4co.envs.common.base import RL4COEnvBase
+from rl4co.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+class SMTWTPEnv(RL4COEnvBase):
+    """
+    Single Machine Total Weighted Tardiness Problem environment as described in DeepACO (https://arxiv.org/pdf/2309.14032.pdf).
+    SMTWTP is a scheduling problem in which a set of jobs must be processed on a single machine.
+    Each job i has a processing time, a weight, and a due date. The objective is to minimize the sum of the weighted tardiness of all jobs,
+    where the weighted tardiness of a job is the product of its weight and the amount by which its completion time exceeds its due date.
+    At each step, the agent chooses the next job to process. The reward is -infinity unless all jobs have been processed;
+    in that case, the reward is the negative objective value of the processing order, so maximizing the reward is equivalent to minimizing the objective.
+
+    Args:
+        num_job: number of jobs
+        min_time_span: lower bound of jobs' due time. By default, jobs' due time is uniformly sampled from (min_time_span, max_time_span)
+        max_time_span: upper bound of jobs' due time. By default, it will be set to num_job / 2
+        min_job_weight: lower bound of jobs' weights. By default, jobs' weights are uniformly sampled from (min_job_weight, max_job_weight)
+        max_job_weight: upper bound of jobs' weights
+        min_process_time: lower bound of jobs' process time. By default, jobs' process time is uniformly sampled from (min_process_time, max_process_time)
+        max_process_time: upper bound of jobs' process time
+        td_params: parameters of the environment
+        seed: seed for the environment
+        device: device to use. Generally, no need to set as tensors are updated on the fly
+    """
+
+    name = "smtwtp"
+
+    def __init__(
+        self,
+        num_job: int = 10,
+        min_time_span: float = 0,
+        max_time_span: float = None,  # set to num_job / 2 by default; DeepACO uses num_job, which makes the problem too easy
+        min_job_weight: float = 0,
+        max_job_weight: float = 1,
+        min_process_time: float = 0,
+        max_process_time: float = 1,
+        td_params: TensorDict = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_job = num_job
+        self.min_time_span = min_time_span
+        self.max_time_span = num_job / 2 if max_time_span is None else max_time_span
+        self.min_job_weight = min_job_weight
+        self.max_job_weight = max_job_weight
+        self.min_process_time = min_process_time
+        self.max_process_time = max_process_time
+        self._make_spec(td_params)
+
+    @staticmethod
+    def _step(td: TensorDict) -> TensorDict:
+        current_job = td["action"]
+
+        # Mark the selected job as visited (set its entry in the action mask to 0)
+        available = td["action_mask"].scatter(
+            -1, current_job.unsqueeze(-1).expand_as(td["action_mask"]), 0
+        )
+
+        # Increase the elapsed machine time by the processing time of the selected job
+        selected_process_time = td["job_process_time"][
+            torch.arange(current_job.size(0)), current_job
+        ]
+        current_time = td["current_time"] + selected_process_time.unsqueeze(-1)
+
+        # We are done if there are no unvisited jobs left
+        done = torch.count_nonzero(available, dim=-1) <= 0
+
+        # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
+        reward = torch.ones_like(done) * float("-inf")
+
+        # The output must be written in a ``"next"`` entry
+        return TensorDict(
+            {
+                "next": {
+                    "job_due_time": td["job_due_time"],
+                    "job_weight": td["job_weight"],
+                    "job_process_time": td["job_process_time"],
+                    "current_job": current_job,
+                    "current_time": current_time,
+                    "action_mask": available,
+                    "reward": reward,
+                    "done": done,
+                }
+            },
+            td.shape,
+        )
+
+    def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
+        # Initialization
+        if batch_size is None:
+            batch_size = self.batch_size if td is None else td["job_due_time"].shape[:-1]
+        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
+
+        self.device = device = (
+            td["job_due_time"].device if td is not None else self.device
+        )
+
+        td = self.generate_data(batch_size) if td is None else td
+
+        init_job_due_time = td["job_due_time"]
+        init_job_process_time = td["job_process_time"]
+        init_job_weight = td["job_weight"]
+
+        # Other variables
+        current_job = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device)
+        current_time = torch.zeros((*batch_size, 1), dtype=torch.float32, device=device)
+        available = torch.ones(
+            (*batch_size, self.num_job + 1), dtype=torch.bool, device=device
+        )
+        available[:, 0] = 0  # mask the starting dummy node
+
+        return TensorDict(
+            {
+                "job_due_time": init_job_due_time,
+                "job_weight": init_job_weight,
+                "job_process_time": init_job_process_time,
+                "current_job": current_job,
+                "current_time": current_time,
+                "action_mask": available,
+            },
+            batch_size=batch_size,
+        )
+
+    def _make_spec(self, td_params: TensorDict = None):
+        self.observation_spec = CompositeSpec(
+            job_due_time=BoundedTensorSpec(
+                minimum=self.min_time_span,
+                maximum=self.max_time_span,
+                shape=(self.num_job + 1,),
+                dtype=torch.float32,
+            ),
+            job_weight=BoundedTensorSpec(
+                minimum=self.min_job_weight,
+                maximum=self.max_job_weight,
+                shape=(self.num_job + 1,),
+                dtype=torch.float32,
+            ),
+            job_process_time=BoundedTensorSpec(
+                minimum=self.min_process_time,
+                maximum=self.max_process_time,
+                shape=(self.num_job + 1,),
+                dtype=torch.float32,
+            ),
+            current_job=UnboundedDiscreteTensorSpec(
+                shape=(1,),
+                dtype=torch.int64,
+            ),
+            action_mask=UnboundedDiscreteTensorSpec(
+                shape=(self.num_job + 1,),
+                dtype=torch.bool,
+            ),
+            current_time=UnboundedContinuousTensorSpec(
+                shape=(1,),
+                dtype=torch.float32,
+            ),
+            shape=(),
+        )
+        self.input_spec = self.observation_spec.clone()
+        self.action_spec = BoundedTensorSpec(
+            shape=(1,),
+            dtype=torch.int64,
+            minimum=0,
+            maximum=self.num_job + 1,
+        )
+        self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
+        self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)
+
+    def get_reward(self, td, actions) -> TensorDict:
+        job_due_time = td["job_due_time"]
+        job_weight = td["job_weight"]
+        job_process_time = td["job_process_time"]
+
+        batch_idx = torch.arange(
+            job_process_time.shape[0], device=job_process_time.device
+        ).unsqueeze(1)
+
+        ordered_process_time = job_process_time[batch_idx, actions]
+        ordered_due_time = job_due_time[batch_idx, actions]
+        ordered_job_weight = job_weight[batch_idx, actions]
+        presum_process_time = torch.cumsum(
+            ordered_process_time, dim=1
+        )  # ending time of each job
+        job_tardiness = presum_process_time - ordered_due_time
+        job_tardiness[job_tardiness < 0] = 0
+        job_weighted_tardiness = ordered_job_weight * job_tardiness
+
+        return -job_weighted_tardiness.sum(-1)
+
+    def generate_data(self, batch_size) -> TensorDict:
+        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
+        # Sampling according to Ye et al. (2023)
+        job_due_time = (
+            torch.FloatTensor(*batch_size, self.num_job + 1)
+            .uniform_(self.min_time_span, self.max_time_span)
+            .to(self.device)
+        )
+        job_weight = (
+            torch.FloatTensor(*batch_size, self.num_job + 1)
+            .uniform_(self.min_job_weight, self.max_job_weight)
+            .to(self.device)
+        )
+        job_process_time = (
+            torch.FloatTensor(*batch_size, self.num_job + 1)
+            .uniform_(self.min_process_time, self.max_process_time)
+            .to(self.device)
+        )
+
+        # Rollouts begin at dummy node 0, whose features are set to 0
+        job_due_time[:, 0] = 0
+        job_weight[:, 0] = 0
+        job_process_time[:, 0] = 0
+
+        return TensorDict(
+            {
+                "job_due_time": job_due_time,
+                "job_weight": job_weight,
+                "job_process_time": job_process_time,
+            },
+            batch_size=batch_size,
+        )
+
+    @staticmethod
+    def render(td, actions=None, ax=None):
+        raise NotImplementedError("TODO: render is not implemented yet")
diff --git a/rl4co/models/nn/env_embeddings/context.py b/rl4co/models/nn/env_embeddings/context.py
index 881c2f3b..30865da1 100644
--- a/rl4co/models/nn/env_embeddings/context.py
+++ b/rl4co/models/nn/env_embeddings/context.py
@@ -25,6 +25,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
         "mdpp": DPPContext,
         "pdp": PDPContext,
         "mtsp": MTSPContext,
+        "smtwtp": SMTWTPContext,
     }
 
     if env_name not in embedding_registry:
@@ -213,3 +214,22 @@ def _distance_from_depot(self, td):
         # Euclidean distance from the depot (loc[..., 0, :])
         cur_loc = gather_by_index(td["locs"], td["current_node"])
         return torch.norm(cur_loc - td["locs"][..., 0, :], dim=-1)
+
+
+class SMTWTPContext(EnvContext):
+    """Context embedding for the Single Machine Total Weighted Tardiness Problem (SMTWTP).
+    Project the following to the embedding space:
+        - current node embedding
+        - current time
+    """
+
+    def __init__(self, embedding_dim):
+        super(SMTWTPContext, self).__init__(embedding_dim, embedding_dim + 1)
+
+    def _cur_node_embedding(self, embeddings, td):
+        cur_node_embedding = gather_by_index(embeddings, td["current_job"])
+        return cur_node_embedding
+
+    def _state_embedding(self, embeddings, td):
+        state_embedding = td["current_time"]
+        return state_embedding
diff --git a/rl4co/models/nn/env_embeddings/dynamic.py b/rl4co/models/nn/env_embeddings/dynamic.py
index 82984779..9c84f692 100644
--- a/rl4co/models/nn/env_embeddings/dynamic.py
+++ b/rl4co/models/nn/env_embeddings/dynamic.py
@@ -26,6 +26,7 @@ def env_dynamic_embedding(env_name: str, config: dict) -> nn.Module:
         "mdpp": StaticEmbedding,
         "pdp": StaticEmbedding,
         "mtsp": StaticEmbedding,
+        "smtwtp": StaticEmbedding,
     }
 
     if env_name not in embedding_registry:
diff --git a/rl4co/models/nn/env_embeddings/init.py b/rl4co/models/nn/env_embeddings/init.py
index 0358bb73..a27514cb 100644
--- a/rl4co/models/nn/env_embeddings/init.py
+++ b/rl4co/models/nn/env_embeddings/init.py
@@ -23,6 +23,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
         "mdpp": MDPPInitEmbedding,
         "pdp": PDPInitEmbedding,
         "mtsp": MTSPInitEmbedding,
+        "smtwtp": SMTWTPInitEmbedding,
     }
 
     if env_name not in embedding_registry:
@@ -246,3 +247,25 @@ def forward(self, td):
         depot_embedding = self.init_embed_depot(td["locs"][..., 0:1, :])
         node_embedding = self.init_embed(td["locs"][..., 1:, :])
         return torch.cat([depot_embedding, node_embedding], -2)
+
+
+class SMTWTPInitEmbedding(nn.Module):
+    """Initial embedding for the Single Machine Total Weighted Tardiness Problem (SMTWTP).
+    Embed the following node features to the embedding space:
+        - job_due_time: due time of the jobs
+        - job_weight: weights of the jobs
+        - job_process_time: the processing time of jobs
+    """
+
+    def __init__(self, embedding_dim, linear_bias=True):
+        super(SMTWTPInitEmbedding, self).__init__()
+        node_dim = 3  # job_due_time, job_weight, job_process_time
+        self.init_embed = nn.Linear(node_dim, embedding_dim, linear_bias)
+
+    def forward(self, td):
+        job_due_time = td["job_due_time"]
+        job_weight = td["job_weight"]
+        job_process_time = td["job_process_time"]
+        feat = torch.stack((job_due_time, job_weight, job_process_time), dim=-1)
+        out = self.init_embed(feat)
+        return out
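
Not part of the patch: a quick way to sanity-check the new environment is a random feasible rollout followed by get_reward. The sketch below is illustrative only and assumes the usual RL4COEnvBase interface used elsewhere in the repo, namely env.reset(batch_size=...) and the td = env.step(td)["next"] stepping pattern; if those signatures differ, the loop needs matching adjustments.

import torch

from rl4co.envs import SMTWTPEnv

# 10 jobs plus the dummy start node at index 0 (masked out at reset)
env = SMTWTPEnv(num_job=10)
td = env.reset(batch_size=[4])  # assumes RL4COEnvBase.reset(batch_size=...) as in the other envs

actions = []
for _ in range(env.num_job):  # every real job (indices 1..num_job) is scheduled exactly once
    # Sample a random feasible job from the boolean action mask
    action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1)
    actions.append(action)
    td["action"] = action
    td = env.step(td)["next"]

actions = torch.stack(actions, dim=1)  # [batch, num_job] processing order
reward = env.get_reward(td, actions)  # negative total weighted tardiness
assert bool(td["done"].all()) and bool((reward <= 0).all())

With the defaults, due dates are drawn from (0, num_job / 2) while the expected total processing time is about num_job / 2, so a random order typically incurs some tardiness and the reward is usually strictly negative.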