diff --git a/python/resdata/summary/rd_sum.py b/python/resdata/summary/rd_sum.py index 709b67f2b..b644a0bc6 100644 --- a/python/resdata/summary/rd_sum.py +++ b/python/resdata/summary/rd_sum.py @@ -13,7 +13,7 @@ import ctypes import pandas as pd import re -from typing import Sequence, List, Tuple, Optional +from typing import Sequence, List, Tuple, Optional, Union # Observe that there is some convention conflict with the C code # regarding order of arguments: The C code generally takes the time @@ -477,7 +477,9 @@ def get_values(self, key, report_only=False): else: raise KeyError("Summary object does not have key:%s" % key) - def _make_time_vector(self, time_index): + def _make_time_vector( + self, time_index: Sequence[Union[CTime, datetime.datetime, int, datetime.date]] + ) -> TimeVector: time_points = TimeVector() for t in time_index: time_points.append(t) @@ -558,7 +560,11 @@ def report_dates(self): dates.append(self.get_report_time(report)) return dates - def pandas_frame(self, time_index=None, column_keys=None): + def pandas_frame( + self, + time_index: Optional[Sequence[datetime.datetime]] = None, + column_keys: Optional[Sequence[str]] = None, + ) -> pd.DataFrame: """Will create a pandas frame with summary data. By default you will get all time points in the summary case, but by @@ -615,8 +621,7 @@ def pandas_frame(self, time_index=None, column_keys=None): data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), ) - frame = pd.DataFrame(index=list(time_index), columns=list(keywords), data=data) - return frame + return pd.DataFrame(index=list(time_index), columns=list(keywords), data=data) @staticmethod def _compile_headers_list( diff --git a/python/tests/rd_tests/test_sum.py b/python/tests/rd_tests/test_sum.py index 4fb455d5d..79c6a94d1 100644 --- a/python/tests/rd_tests/test_sum.py +++ b/python/tests/rd_tests/test_sum.py @@ -1,4 +1,5 @@ import csv +import pytest import datetime import os import os.path @@ -27,6 +28,7 @@ def assert_frame_equal(a, b): from resdata.summary import Summary, SummaryKeyWordVector, SummaryVarType from resdata.util.test import TestAreaContext from resdata.util.test.mock import createSummary +from resdata.util.util import CTime, TimeVector from tests import ResdataTest @@ -597,30 +599,6 @@ def test_wells_and_groups(self): self.assertEqual(case.wells(), []) self.assertEqual(case.groups(), []) - def test_pandas(self): - case = create_case() - dates = ( - [datetime.datetime(2000, 1, 1)] - + case.dates - + [datetime.datetime(2020, 1, 1)] - ) - frame = case.pandas_frame(column_keys=["FOPT", "FOPR"], time_index=dates) - - fopr = frame["FOPR"] - fopt = frame["FOPT"] - - self.assertEqual(fopr[0], 0) - self.assertEqual(fopr[-1], 0) - - self.assertEqual(fopt[0], 0) - self.assertEqual(fopt[0], case.first_value("FOPT")) - self.assertEqual(fopt[-1], case.last_value("FOPT")) - - frame = case.pandas_frame() - rows, columns = frame.shape - self.assertEqual(len(case.keys()), columns) - self.assertEqual(len(case), rows) - def test_csv_load(self): case = create_case2() frame = case.pandas_frame() @@ -692,7 +670,6 @@ def test_resample_extrapolate(self): """ Test resampling of summary with extrapolate option of lower and upper boundaries enabled """ - from resdata.util.util import CTime, TimeVector time_points = TimeVector() @@ -769,6 +746,37 @@ def test_pandas2_compatibility_dataframe_index(self): +def create_time_vector(lst): + vec = TimeVector() + for l in lst: + vec.append(l) + return vec + + +@pytest.mark.parametrize("time_index_type", [list, create_time_vector, tuple]) +def test_pandas(time_index_type): + case = create_case() + dates = time_index_type( + [datetime.datetime(2000, 1, 1)] + case.dates + [datetime.datetime(2020, 1, 1)] + ) + frame = case.pandas_frame(column_keys=["FOPT", "FOPR"], time_index=dates) + + fopr = frame["FOPR"] + fopt = frame["FOPT"] + + assert fopr[0] == 0 + assert fopr[-1] == 0 + + assert fopt[0] == 0 + assert fopt[0] == case.first_value("FOPT") + assert fopt[-1] == case.last_value("FOPT") + + frame = case.pandas_frame() + rows, columns = frame.shape + assert len(case.keys()) == columns + assert len(case) == rows + + def test_t_step(): sum = createSummary( "CASE",