Skip to content

Commit

Permalink
details
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoGrin committed Dec 18, 2024
1 parent 27b7534 commit 929e3e1
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,12 @@ def test_predict_uses_correct_model_path(self):

self.assertEqual(predict_kwargs["config"]["model_path"], expected_model_path)

def test_paper_version_behavior_regression(self):
def test_paper_version_behavior(self):
# this just tests that it doesn't break,
# but the actual behavior is easier to test
# on the server side
X = np.random.rand(10, 5)
y = np.random.rand(10) # Continuous target for regression
y = np.random.rand(10)
test_X = np.random.rand(5, 5)

# Mock the inference handler
Expand All @@ -474,13 +474,12 @@ def test_paper_version_behavior_regression(self):
y_pred_false = tabpfn_false.predict(test_X)
self.assertIsNotNone(y_pred_false)

def test_check_paper_version_with_non_numerical_data_raises_error_regression(self):
# Create a TabPFNRegressor with paper_version=True
def test_check_paper_version_with_non_numerical_data_raises_error(self):
tabpfn = TabPFNRegressor(paper_version=True)

# Create non-numerical data
X = pd.DataFrame({"feature1": ["a", "b", "c"], "feature2": ["d", "e", "f"]})
y = np.array([0.1, 0.2, 0.3]) # Continuous target for regression
y = np.array([0.1, 0.2, 0.3])

# Mock the inference handler
config.g_tabpfn_config.inference_handler = MagicMock()
Expand Down

0 comments on commit 929e3e1

Please sign in to comment.