Skip to content

Commit

Permalink
[Feat,BugFix] solve #106; add checks and better docs
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Dec 5, 2023
1 parent 548897d commit 428d8b3
Showing 1 changed file with 50 additions and 20 deletions.
70 changes: 50 additions & 20 deletions rl4co/data/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import sys

from typing import List, Union

import numpy as np

from rl4co.data.utils import check_extension
Expand Down Expand Up @@ -212,26 +214,42 @@ def generate_mdpp_data(


def generate_dataset(
filename=None,
data_dir="data",
name=None,
problem="all",
data_distribution="all",
dataset_size=10000,
graph_sizes=[20, 50, 100],
overwrite=False,
seed=1234,
disable_warning=True,
distributions_per_problem=None,
filename: Union[str, List[str]] = None,
data_dir: str = "data",
name: str = None,
problem: Union[str, List[str]] = "all",
data_distribution: str = "all",
dataset_size: int = 10000,
graph_sizes: Union[int, List[int]] = [20, 50, 100],
overwrite: bool = False,
seed: int = 1234,
disable_warning: bool = True,
distributions_per_problem: Union[int, dict] = None,
):
"""We keep a similar structure as in Kool et al. 2019 but save and load the data as npz.
This is way faster and more memory efficient than pickle and also allows for easy transfer to TensorDict.
Args:
filename: Filename (or list of filenames) to save the data to. If None, the data is saved to data_dir/problem/problem_graph_size_seed.npz. Defaults to None.
data_dir: Directory to save the data to. Defaults to "data".
name: Name of the dataset. Defaults to None.
problem: Problem to generate data for. If "all", data is generated for every problem in distributions_per_problem. Defaults to "all".
data_distribution: Data distribution to generate data for. Defaults to "all".
dataset_size: Number of instances to generate per dataset. Defaults to 10000.
graph_sizes: Graph size(s) to generate data for, given as a single int or a list of ints. Defaults to [20, 50, 100].
overwrite: Whether to overwrite existing files. Defaults to False.
seed: Random seed. Defaults to 1234.
disable_warning: Whether to disable warnings. Defaults to True.
distributions_per_problem: Mapping from each problem to the data distributions to generate for it. If None, DISTRIBUTIONS_PER_PROBLEM is used. Defaults to None.
"""
assert filename is None or (
len(problem) == 1 and len(graph_sizes) == 1
), "Can only specify filename when generating a single dataset"

distributions_per_problem = DISTRIBUTIONS_PER_PROBLEM
if isinstance(problem, list) and len(problem) == 1:
problem = problem[0]

graph_sizes = [graph_sizes] if isinstance(graph_sizes, int) else graph_sizes

if distributions_per_problem is None:
distributions_per_problem = DISTRIBUTIONS_PER_PROBLEM

if problem == "all":
problems = distributions_per_problem
Expand All @@ -241,15 +259,18 @@ def generate_dataset(
if data_distribution == "all"
else [data_distribution]
}
# breakpoint()
fname = filename

# Support multiple filenames if necessary
filenames = [filename] if isinstance(filename, str) else filename
iter = 0

# Main loop for data generation. We loop over all problems, distributions and sizes
for problem, distributions in problems.items():
for distribution in distributions or [None]:
for graph_size in graph_sizes:
datadir = os.path.join(data_dir, problem)
os.makedirs(datadir, exist_ok=True)

if filename is None:
datadir = os.path.join(data_dir, problem)
os.makedirs(datadir, exist_ok=True)
fname = os.path.join(
datadir,
"{}{}{}_{}_seed{}.npz".format(
Expand All @@ -263,6 +284,15 @@ def generate_dataset(
),
)
else:
try:
fname = filenames[iter]
# make directory if necessary
os.makedirs(os.path.dirname(fname), exist_ok=True)
iter += 1
except Exception:
raise ValueError(
"Number of filenames does not match number of problems"
)
fname = check_extension(filename, extension=".npz")

if not overwrite and os.path.isfile(
Expand Down

0 comments on commit 428d8b3

Please sign in to comment.