Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Jan 12, 2024
1 parent 02c74fb commit 1347cb3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 7 additions & 9 deletions tests/dfencoder/test_scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def test_modified_scaler_transform(modified_scaler, tensor):
assert torch.equal(torch.round(results, decimals=2), expected), f"{results} != {expected}"

# Test alternate path where median absolute deviation is 1
t = torch.tensor([3.0, 4.0, 4.0, 5.0])
modified_scaler.fit(t)
modified_scaler.fit(torch.tensor([3.0, 4.0, 4.0, 5.0]))
results = modified_scaler.transform(tensor)
expected = torch.tensor([5.43, 6.86, 8.78])
assert torch.equal(torch.round(results, decimals=2), expected), f"{results} != {expected}"
Expand All @@ -128,8 +127,7 @@ def test_modified_scaler_inverse_transform(modified_scaler, tensor):
assert torch.equal(torch.round(results, decimals=2), expected), f"{results} != {expected}"

# Test alternate path where median absolute deviation is 1
t = torch.tensor([3.0, 4.0, 4.0, 5.0])
modified_scaler.fit(t)
modified_scaler.fit(torch.tensor([3.0, 4.0, 4.0, 5.0]))
results = modified_scaler.inverse_transform(tensor)
expected = torch.tensor([8.64, 9.2, 9.95])
assert torch.equal(torch.round(results, decimals=2), expected), f"{results} != {expected}"
Expand Down Expand Up @@ -161,13 +159,13 @@ def test_gauss_rank_scaler_fit_transform(gauss_rank_scaler, tensor):

def test_null_scaler(tensor):
orig = tensor.to(dtype=torch.float32, copy=True)
ns = scalers.NullScaler()
ns.fit(tensor)
scalar = scalers.NullScaler()
scalar.fit(tensor)

# Verify it does nothing
assert ns.transform(tensor) is tensor
assert ns.inverse_transform(tensor) is tensor
assert ns.fit_transform(tensor) is tensor
assert scalar.transform(tensor) is tensor
assert scalar.inverse_transform(tensor) is tensor
assert scalar.fit_transform(tensor) is tensor

# After all that the values should be the same
assert torch.equal(tensor, orig), f"{tensor} != {orig}"
2 changes: 1 addition & 1 deletion tests/llm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def mock_nemollm_fixture(mock_nemollm: mock.MagicMock):
async def mock_task(fut: asyncio.Future, value: typing.Any = mock.DEFAULT):
fut.set_result(value)

def create_future(*args, **kwargs) -> asyncio.Future:
def create_future(*args, **kwargs) -> asyncio.Future: # pylint: disable=unused-argument
event_loop = asyncio.get_event_loop()
fut = event_loop.create_future()
event_loop.create_task(mock_task(fut, mock.DEFAULT))
Expand Down

0 comments on commit 1347cb3

Please sign in to comment.