Skip to content

Commit

Permalink
Fix and test scalar list and tabulate for specific runs
Browse files Browse the repository at this point in the history
  • Loading branch information
glatterf42 committed Aug 2, 2024
1 parent 12b2421 commit c306efc
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 22 deletions.
4 changes: 3 additions & 1 deletion ixmp4/core/optimization/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,6 @@ def list(self, name: str | None = None) -> Iterable[Scalar]:
]

def tabulate(self, name: str | None = None) -> pd.DataFrame:
return self.backend.optimization.scalars.tabulate(name=name)
return self.backend.optimization.scalars.tabulate(
run_id=self._run.id, name=name
)
42 changes: 31 additions & 11 deletions tests/core/test_scalar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import pytest

from ixmp4 import Scalar
from ixmp4.core import Platform, Scalar

from ..utils import all_platforms, assert_unordered_equality

Expand Down Expand Up @@ -35,7 +35,7 @@ def df_from_list(scalars: list[Scalar]):
@all_platforms
class TestCoreScalar:
def test_create_scalar(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
unit = test_mp.units.create("Test Unit")
scalar_1 = run.optimization.scalars.create(
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_create_scalar(self, test_mp, request):
assert scalar_3.unit.name == ""

def test_get_scalar(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
unit = test_mp.units.create("Test Unit")
scalar = run.optimization.scalars.create("Scalar", value=10, unit=unit.name)
Expand All @@ -77,7 +77,7 @@ def test_get_scalar(self, test_mp, request):
_ = run.optimization.scalars.get("Foo")

def test_update_scalar(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
unit = test_mp.units.create("Test Unit")
unit2 = test_mp.units.create("Test Unit 2")
Expand All @@ -100,10 +100,8 @@ def test_update_scalar(self, test_mp, request):
assert scalar.unit.id == result.unit.id == 1

def test_list_scalars(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
# Per default, list() lists only `default` version runs:
run.set_as_default()
unit = test_mp.units.create("Test Unit")
scalar_1 = run.optimization.scalars.create(
"Scalar 1", value=1, unit="Test Unit"
Expand All @@ -120,11 +118,21 @@ def test_list_scalars(self, test_mp, request):
]
assert not (set(expected_id) ^ set(list_id))

# Test that only Scalars belonging to this Run are listed
run_2 = test_mp.runs.create("Model", "Scenario")
scalar_3 = run_2.optimization.scalars.create(
"Scalar 1", value=1, unit="Test Unit"
)
scalar_4 = run_2.optimization.scalars.create(
"Scalar 2", value=2, unit=unit.name
)
expected_ids = [scalar_3.id, scalar_4.id]
list_ids = [scalar.id for scalar in run_2.optimization.scalars.list()]
assert not (set(expected_ids) ^ set(list_ids))

def test_tabulate_scalars(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
# Per default, tabulate() lists only `default` version runs:
run.set_as_default()
unit = test_mp.units.create("Test Unit")
scalar_1 = run.optimization.scalars.create("Scalar 1", value=1, unit=unit.name)
scalar_2 = run.optimization.scalars.create("Scalar 2", value=2, unit=unit.name)
Expand All @@ -136,8 +144,20 @@ def test_tabulate_scalars(self, test_mp, request):
result = run.optimization.scalars.tabulate(name="Scalar 2")
assert_unordered_equality(expected, result, check_dtype=False)

# Test that only Scalars belonging to this Run are tabulated
run_2 = test_mp.runs.create("Model", "Scenario")
scalar_3 = run_2.optimization.scalars.create(
"Scalar 1", value=1, unit=unit.name
)
scalar_4 = run_2.optimization.scalars.create(
"Scalar 2", value=2, unit=unit.name
)
expected = df_from_list(scalars=[scalar_3, scalar_4])
result = run_2.optimization.scalars.tabulate()
assert_unordered_equality(expected, result, check_dtype=False)

def test_scalar_docs(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
unit = test_mp.units.create("Test Unit")
scalar = run.optimization.scalars.create("Scalar 1", value=4, unit=unit.name)
Expand Down
41 changes: 31 additions & 10 deletions tests/data/test_optimization_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas.testing as pdt
import pytest

from ixmp4 import Scalar
from ixmp4.core import Platform, Scalar

from ..utils import all_platforms

Expand Down Expand Up @@ -36,7 +36,7 @@ def df_from_list(scalars: list):
@all_platforms
class TestDataOptimizationScalar:
def test_create_scalar(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.backend.runs.create("Model", "Scenario")
unit = test_mp.backend.units.create("Unit")
unit2 = test_mp.backend.units.create("Unit 2")
Expand All @@ -54,7 +54,7 @@ def test_create_scalar(self, test_mp, request):
)

def test_get_scalar(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.backend.runs.create("Model", "Scenario")
unit = test_mp.backend.units.create("Unit")
scalar = test_mp.backend.optimization.scalars.create(
Expand All @@ -68,7 +68,7 @@ def test_get_scalar(self, test_mp, request):
_ = test_mp.backend.optimization.scalars.get(run_id=run.id, name="Scalar 2")

def test_update_scalar(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.backend.runs.create("Model", "Scenario")
unit = test_mp.backend.units.create("Unit")
unit2 = test_mp.backend.units.create("Unit 2")
Expand All @@ -87,10 +87,8 @@ def test_update_scalar(self, test_mp, request):
assert ret.value == 20

def test_list_scalars(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.backend.runs.create("Model", "Scenario")
# Per default, list() lists scalars for `default` version runs:
test_mp.backend.runs.set_as_default_version(run.id)
unit = test_mp.backend.units.create("Unit")
unit2 = test_mp.backend.units.create("Unit 2")
scalar_1 = test_mp.backend.optimization.scalars.create(
Expand All @@ -102,11 +100,21 @@ def test_list_scalars(self, test_mp, request):
assert [scalar_1] == test_mp.backend.optimization.scalars.list(name="Scalar")
assert [scalar_1, scalar_2] == test_mp.backend.optimization.scalars.list()

# Test listing of scalars of particular run only
run_2 = test_mp.backend.runs.create("Model", "Scenario")
scalar_3 = test_mp.backend.optimization.scalars.create(
run_id=run_2.id, name="Scalar", value=1, unit_name=unit.name
)
scalar_4 = test_mp.backend.optimization.scalars.create(
run_id=run_2.id, name="Scalar 2", value=2, unit_name=unit2.name
)
assert [scalar_3, scalar_4] == test_mp.backend.optimization.scalars.list(
run_id=run_2.id
)

def test_tabulate_scalars(self, test_mp, request):
test_mp = request.getfixturevalue(test_mp)
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.backend.runs.create("Model", "Scenario")
# Per default, tabulate() lists scalars for `default` version runs:
test_mp.backend.runs.set_as_default_version(run.id)
unit = test_mp.backend.units.create("Unit")
unit2 = test_mp.backend.units.create("Unit 2")
scalar_1 = test_mp.backend.optimization.scalars.create(
Expand All @@ -124,3 +132,16 @@ def test_tabulate_scalars(self, test_mp, request):
pdt.assert_frame_equal(
expected, test_mp.backend.optimization.scalars.tabulate(name="Scalar")
)

# Test tabulation of scalars of particular run only
run_2 = test_mp.backend.runs.create("Model", "Scenario")
scalar_3 = test_mp.backend.optimization.scalars.create(
run_id=run_2.id, name="Scalar", value=1, unit_name=unit.name
)
scalar_4 = test_mp.backend.optimization.scalars.create(
run_id=run_2.id, name="Scalar 2", value=2, unit_name=unit2.name
)
expected = df_from_list(scalars=[scalar_3, scalar_4])
pdt.assert_frame_equal(
expected, test_mp.backend.optimization.scalars.tabulate(run_id=run_2.id)
)

0 comments on commit c306efc

Please sign in to comment.