-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
# Conflicts: # dist/lightautoml-0.3.8b1-py3-none-any.whl
- Loading branch information
Showing
8 changed files
with
2,432 additions
and
30 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import pandas as pd | ||
from lightautoml.addons.hypex.ABTesting.ab_tester import AATest | ||
from lightautoml.addons.hypex.utils.tutorial_data_creation import create_test_data | ||
|
||
|
||
def test_aa_simple():
    """Check the outputs of ``AATest.search_dist_uniform_sampling`` without grouping.

    Builds an ``AATest`` on generated test data with only ``target_fields`` and
    ``info_cols`` set, runs the sampling search for a fixed number of
    iterations, and verifies the type and size of both returned objects as
    well as the column names of the first returned dataframe.
    """
    data = create_test_data(rs=52)
    info_col = "user_id"
    iterations = 20

    model = AATest(
        data=data,
        target_fields=["pre_spends", "post_spends"],
        info_cols=info_col,
    )
    res, datas_dict = model.search_dist_uniform_sampling(iterations=iterations)

    assert isinstance(res, pd.DataFrame), "Metrics are not dataframes"
    assert res.shape[0] == iterations, "Metrics dataframe contains more or less rows with random states " \
                                       "(#rows should be equal #of experiments"
    assert info_col not in model.data, "Info_col is take part in experiment, it should be deleted in preprocess"
    assert isinstance(datas_dict, dict), "Result is not dict"
    assert len(datas_dict) == iterations, "# of dataframes is not equal # of iterations"

    # Compare the actual column names. `all(a) == all(b)` would only compare
    # two truthiness flags and pass for any pair of non-empty column lists.
    initial_columns = set(data.columns)
    result_columns = set(datas_dict[0].drop(columns=['group']).columns)
    assert initial_columns == result_columns, \
        "Columns in the result are not the same as columns in initial data "
|
||
|
||
def test_aa_group():
    """Check the outputs of ``AATest.search_dist_uniform_sampling`` with ``group_cols``.

    Same checks as the ungrouped scenario, but the ``AATest`` is constructed
    with ``group_cols='industry'``: verifies the type and size of both
    returned objects and the column names of the first returned dataframe.
    """
    data = create_test_data(rs=52)
    info_col = "user_id"
    group_cols = 'industry'
    iterations = 20

    model = AATest(
        data=data,
        target_fields=["pre_spends", "post_spends"],
        info_cols=info_col,
        group_cols=group_cols,
    )
    res, datas_dict = model.search_dist_uniform_sampling(iterations=iterations)

    assert isinstance(res, pd.DataFrame), "Metrics are not dataframes"
    assert res.shape[0] == iterations, "Metrics dataframe contains more or less rows with random states " \
                                       "(#rows should be equal #of experiments"
    assert info_col not in model.data, "Info_col is take part in experiment, it should be deleted in preprocess"
    assert isinstance(datas_dict, dict), "Result is not dict"
    assert len(datas_dict) == iterations, "# of dataframes is not equal # of iterations"

    # Compare the actual column names. `all(a) == all(b)` would only compare
    # two truthiness flags and pass for any pair of non-empty column lists.
    initial_columns = set(data.columns)
    result_columns = set(datas_dict[0].drop(columns=['group']).columns)
    assert initial_columns == result_columns, "Columns in the result are not " \
                                              "the same as columns in initial " \
                                              "data "
|
||
|
||
def test_aa_quantfields():
    """Check the outputs of ``AATest.search_dist_uniform_sampling`` with ``quant_field``.

    The ``AATest`` is constructed with ``group_cols='industry'`` and
    ``quant_field='gender'``; the test verifies the type and size of both
    returned objects and the column names of the first returned dataframe.
    """
    data = create_test_data(rs=52)
    info_col = "user_id"
    group_cols = 'industry'
    quant_field = 'gender'
    iterations = 20

    model = AATest(
        data=data,
        target_fields=["pre_spends", "post_spends"],
        info_cols=info_col,
        group_cols=group_cols,
        quant_field=quant_field,
    )
    res, datas_dict = model.search_dist_uniform_sampling(iterations=iterations)

    assert isinstance(res, pd.DataFrame), "Metrics are not dataframes"
    assert res.shape[0] == iterations, "Metrics dataframe contains more or less rows with random states " \
                                       "(#rows should be equal #of experiments"
    assert info_col not in model.data, "Info_col is take part in experiment, it should be deleted in preprocess"
    assert isinstance(datas_dict, dict), "Result is not dict"
    assert len(datas_dict) == iterations, "# of dataframes is not equal # of iterations"

    # Compare the actual column names. `all(a) == all(b)` would only compare
    # two truthiness flags and pass for any pair of non-empty column lists.
    initial_columns = set(data.columns)
    result_columns = set(datas_dict[0].drop(columns=['group']).columns)
    assert initial_columns == result_columns, "Columns in the result are not " \
                                              "the same as columns in initial " \
                                              "data "
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from lightautoml.addons.hypex.ABTesting.ab_tester import ABTest | ||
from lightautoml.addons.hypex.utils.tutorial_data_creation import create_test_data | ||
|
||
|
||
# def test_split_ab(): | ||
# data = create_test_data() | ||
# half_data = int(data.shape[0] / 2) | ||
# data['group'] = ['test'] * half_data + ['control'] * half_data | ||
# | ||
# group_field = 'group' | ||
# | ||
# model = ABTest() | ||
# splitted_data = model.split_ab(data, group_field) | ||
# | ||
# assert isinstance(splitted_data, dict), "result of split_ab is not dict" | ||
# assert len(splitted_data) == 2, "split_ab contains not of 2 values" | ||
# assert list(splitted_data.keys()) == ['test', 'control'], "changed keys in result of split_ab" | ||
# | ||
# | ||
# def test_calc_difference(): | ||
# data = create_test_data() | ||
# half_data = int(data.shape[0] / 2) | ||
# data['group'] = ['test'] * half_data + ['control'] * half_data | ||
# | ||
# group_field = 'group' | ||
# target_field = 'post_spends' | ||
# | ||
# model = ABTest() | ||
# splitted_data = model.split_ab(data, group_field) | ||
# differences = model.calc_difference(splitted_data, target_field) | ||
# | ||
# assert isinstance(differences, dict), "result of calc_difference is not dict" | ||
|
||
|
||
def test_calc_p_value():
    """Check that ``ABTest.calc_p_value`` returns a dict.

    The generated data is labelled with a ``group`` column (first half
    ``'test'``, second half ``'control'``), split with ``ABTest.split_ab``,
    and the split is passed to ``calc_p_value`` for ``'post_spends'``.
    """
    frame = create_test_data()

    # Label the first half of the rows 'test' and the second half 'control'.
    group_size = int(frame.shape[0] / 2)
    labels = ['test'] * group_size
    labels += ['control'] * group_size
    frame['group'] = labels

    tester = ABTest()
    groups = tester.split_ab(frame, 'group')
    p_values = tester.calc_p_value(groups, 'post_spends')

    assert isinstance(p_values, dict), "result of calc_p_value is not dict"
|
||
|
||
def test_execute():
    """Check the shape of the result returned by ``ABTest.execute``.

    The generated data is labelled with a ``group`` column (first half
    ``'test'``, second half ``'control'``). The result is expected to be a
    dict with exactly the keys ``'size'``, ``'difference'`` and ``'p_value'``,
    in that order.
    """
    frame = create_test_data()

    # Label the first half of the rows 'test' and the second half 'control'.
    group_size = int(frame.shape[0] / 2)
    frame['group'] = ['test'] * group_size + ['control'] * group_size

    execute_kwargs = {
        "data": frame,
        "target_field": 'post_spends',
        "target_field_before": 'pre_spends',
        "group_field": 'group',
    }
    tester = ABTest()
    outcome = tester.execute(**execute_kwargs)

    assert isinstance(outcome, dict), "result of func execution is not dict"
    assert len(outcome) == 3, "result of execution is changed, len of dict was 3"
    assert list(outcome.keys()) == ['size', 'difference', 'p_value']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters