Skip to content

Commit

Permalink
addressing tempfile problem on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewq11 committed Nov 13, 2024
1 parent e796093 commit 474c802
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions tests/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
import torch
import pandas as pd
import tempfile
import pytest

import graphium
from graphium.utils.fs import rm, exists, get_size
from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule

Expand All @@ -29,10 +27,6 @@

class test_DataModule(ut.TestCase):

@pytest.fixture
def _setup_tmp_path(self, tmp_path):
self.tmp_path = tmp_path

def test_ogb_datamodule(self):
# other datasets are too large to be tested
dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
Expand Down Expand Up @@ -386,7 +380,6 @@ def test_datamodule_multiple_data_files(self):

self.assertEqual(len(ds.train_ds), 20)

@pytest.mark.usefixtures("_setup_tmp_path")
def test_splits_file(self):
# Test single CSV files
csv_file = "tests/data/micro_ZINC_shard_1.csv"
Expand Down Expand Up @@ -432,7 +425,7 @@ def test_splits_file(self):

try:
# Create a TemporaryFile to save the splits, and test the datamodule
temp_file = tempfile.NamedTemporaryFile(suffix=".pt", dir=self.tmp_path)
temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)

# Save the splits
torch.save(splits, temp_file)
Expand Down Expand Up @@ -479,7 +472,8 @@ def test_splits_file(self):
np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor)

finally:
temp_file.close()
temp_file.close()
os.unlink(temp_file.name)


if __name__ == "__main__":
Expand Down

0 comments on commit 474c802

Please sign in to comment.