From 69ca640dd6d52860d9e1ba5701ee06b0aedb0a1f Mon Sep 17 00:00:00 2001 From: Huazhong Ji Date: Wed, 14 Feb 2024 21:55:12 +0800 Subject: [PATCH] Set the dataset format used by `test_trainer` to float32 (#28920) Co-authored-by: unit_test --- tests/trainer/test_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 55cc35cf6aa3eb..2a098007852c87 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -176,8 +176,8 @@ def __init__(self, length=64, seed=42, batch_size=8): np.random.seed(seed) sizes = np.random.randint(1, 20, (length // batch_size,)) # For easy batching, we make every batch_size consecutive samples the same size. - self.xs = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)] - self.ys = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)] + self.xs = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)] + self.ys = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)] def __len__(self): return self.length @@ -547,7 +547,7 @@ def test_trainer_with_datasets(self): np.random.seed(42) x = np.random.normal(size=(64,)).astype(np.float32) - y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,)) + y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,)).astype(np.float32) train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y}) # Base training. Should have the same results as test_reproducible_training