Skip to content

Commit

Permalink
Set the dataset format used by test_trainer to float32 (#28920)
Browse files Browse the repository at this point in the history
Co-authored-by: unit_test <[email protected]>
  • Loading branch information
statelesshz and unit_test authored Feb 14, 2024
1 parent 7252e8d commit 69ca640
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def __init__(self, length=64, seed=42, batch_size=8):
np.random.seed(seed)
sizes = np.random.randint(1, 20, (length // batch_size,))
# For easy batching, we make every batch_size consecutive samples the same size.
self.xs = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
self.ys = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
self.xs = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)]
self.ys = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)]

def __len__(self):
return self.length
Expand Down Expand Up @@ -547,7 +547,7 @@ def test_trainer_with_datasets(self):

np.random.seed(42)
x = np.random.normal(size=(64,)).astype(np.float32)
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,)).astype(np.float32)
train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})

# Base training. Should have the same results as test_reproducible_training
Expand Down

0 comments on commit 69ca640

Please sign in to comment.