From 3f9695a2f908fc071f6f36d75d6c1007504fa127 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Wed, 11 Dec 2024 23:22:54 -0800 Subject: [PATCH] Simplified and optimized logic for calculating per-metric subsampling rate for MapData (#3106) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3106 This refines the logic for calculating per-metric subsampling rates in `MapData.subsample` and incorporates a (probably premature) performance optimization, achieved by utilizing binary search on a sorted list instead of linear search. Reviewed By: Balandat Differential Revision: D66366076 --- ax/core/map_data.py | 69 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/ax/core/map_data.py b/ax/core/map_data.py index 35f777d38b1..06da05e4a71 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -7,12 +7,14 @@ from __future__ import annotations +from bisect import bisect_right from collections.abc import Iterable, Sequence from copy import deepcopy from logging import Logger from typing import Any, Generic, TypeVar import numpy as np +import numpy.typing as npt import pandas as pd from ax.core.data import Data from ax.core.types import TMapTrialEvaluation @@ -412,6 +414,48 @@ def subsample( ) +def _ceil_divide( + a: int | np.int_ | npt.NDArray[np.int_], b: int | np.int_ | npt.NDArray[np.int_] +) -> np.int_ | npt.NDArray[np.int_]: + return -np.floor_divide(-a, b) + + +def _subsample_rate( + map_df: pd.DataFrame, + keep_every: int | None = None, + limit_rows_per_group: int | None = None, + limit_rows_per_metric: int | None = None, +) -> int: + if keep_every is not None: + return keep_every + + grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS) + group_sizes = grouped_map_df.size() + max_rows = group_sizes.max() + + if limit_rows_per_group is not None: + return _ceil_divide(max_rows, limit_rows_per_group).item() + + if limit_rows_per_metric is not None: + # search for the `keep_every` such that when you apply it to each group, + # the total number of rows is smaller than `limit_rows_per_metric`. + ks = np.arange(max_rows, 0, -1) + # total sizes in ascending order + total_sizes = np.sum( + _ceil_divide(group_sizes.values, ks[..., np.newaxis]), axis=1 + ) + # binary search + i = bisect_right(total_sizes, limit_rows_per_metric) + # if no such `k` is found, then `derived_keep_every` stays as 1. + if i > 0: + return ks[i - 1].item() + + raise ValueError( + "at least one of `keep_every`, `limit_rows_per_group`, " + "or `limit_rows_per_metric` must be specified." + ) + + def _subsample_one_metric( map_df: pd.DataFrame, map_key: str | None = None, @@ -421,30 +465,21 @@ def _subsample_one_metric( include_first_last: bool = True, ) -> pd.DataFrame: """Helper function to subsample a dataframe that holds a single metric.""" - derived_keep_every = 1 - if keep_every is not None: - derived_keep_every = keep_every - elif limit_rows_per_group is not None: - max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max() - derived_keep_every = np.ceil(max_rows / limit_rows_per_group) - elif limit_rows_per_metric is not None: - group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy() - # search for the `keep_every` such that when you apply it to each group, - # the total number of rows is smaller than `limit_rows_per_metric`. - for k in range(1, group_sizes.max() + 1): - if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric: - derived_keep_every = k - break - # if no such `k` is found, then `derived_keep_every` stays as 1. + + grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS) + + derived_keep_every = _subsample_rate( + map_df, keep_every, limit_rows_per_group, limit_rows_per_metric + ) if derived_keep_every <= 1: filtered_map_df = map_df else: filtered_dfs = [] - for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS): + for _, df_g in grouped_map_df: df_g = df_g.sort_values(map_key) if include_first_last: - rows_per_group = int(np.ceil(len(df_g) / derived_keep_every)) + rows_per_group = _ceil_divide(len(df_g), derived_keep_every) linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group) idcs = np.round(linspace_idcs).astype(int) filtered_df = df_g.iloc[idcs]