Skip to content

Commit

Permalink
Add bucket code
Browse files Browse the repository at this point in the history
  • Loading branch information
ghiggi committed Aug 25, 2023
1 parent 6511254 commit 2180439
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 37 deletions.
66 changes: 31 additions & 35 deletions gpm_api/bucket/analysis.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 25 11:18:54 2023
@author: ghiggi
"""

import polars as pl
import numpy as np
import pandas as pd
import numpy as np
import pandas as pd
import polars as pl


def get_n_decimals(number):
number_str = str(number)
decimal_index = number_str.find('.')
decimal_index = number_str.find(".")

if decimal_index == -1:
return 0 # No decimal point found

# Count the number of characters after the decimal point
n_decimals = len(number_str) - decimal_index - 1
return n_decimals


def get_lat_bins(bin_spacing):
n_decimals = get_n_decimals(bin_spacing)
lat_buckets = np.arange(-90.0, 90.0+1e-6, bin_spacing).round(n_decimals)
lat_buckets = np.arange(-90.0, 90.0 + 1e-6, bin_spacing).round(n_decimals)
if lat_buckets[-1] != 90.0:
lat_buckets = np.append(lat_buckets, np.array([90.0]))
return lat_buckets


def get_lon_bins(bin_spacing):
n_decimals = get_n_decimals(bin_spacing)
lon_buckets = np.arange(-180.0, 180.0+1e-6, bin_spacing).round(n_decimals)
lon_buckets = np.arange(-180.0, 180.0 + 1e-6, bin_spacing).round(n_decimals)
if lon_buckets[-1] != 180.0:
lon_buckets = np.append(lon_buckets, np.array([180.0]))
return lon_buckets
Expand All @@ -42,26 +41,26 @@ def get_lon_bins(bin_spacing):
def get_lon_labels(bin_spacing):
n_decimals = get_n_decimals(bin_spacing)
lon_buckets = get_lon_bins(bin_spacing)
lon_labels = lon_buckets[0:-1] + bin_spacing/2
lon_labels = lon_labels.round(n_decimals+1)
lon_labels = lon_buckets[0:-1] + bin_spacing / 2
lon_labels = lon_labels.round(n_decimals + 1)
return lon_labels


def get_lat_labels(bin_spacing):
n_decimals = get_n_decimals(bin_spacing)
lat_buckets = get_lat_bins(bin_spacing)
lat_labels = lat_buckets[0:-1] + bin_spacing/2
lat_labels = lat_labels.round(n_decimals+1)
lat_labels = lat_buckets[0:-1] + bin_spacing / 2
lat_labels = lat_labels.round(n_decimals + 1)
return lat_labels


def get_cut_lat_breaks_labels(bin_spacing):
lat_labels = get_lat_labels(bin_spacing)
lat_buckets = get_lat_bins(bin_spacing)
# Define cut labels
# Define cut labels
cut_lat_labels = lat_labels.astype(str).tolist()
cut_lat_labels = ["outside_left"] + cut_lat_labels + ["outside_right"]
# Deal with left inclusion
# Deal with left inclusion
cut_lat_breaks = lat_buckets
cut_lat_breaks[0] = cut_lat_breaks[0] - 1e-6
return cut_lat_breaks, cut_lat_labels
Expand All @@ -72,56 +71,53 @@ def get_cut_lon_breaks_labels(bin_spacing):
lon_buckets = get_lon_bins(bin_spacing)
cut_lon_labels = lon_labels.astype(str).tolist()
cut_lon_labels = ["outside_left"] + cut_lon_labels + ["outside_right"]
# - Deal with left inclusion
# - Deal with left inclusion
cut_lon_breaks = lon_buckets
cut_lon_breaks[0] = cut_lon_breaks[0] - 1e-6
return cut_lon_breaks, cut_lon_labels


def pl_add_geographic_bins(df,
xbin_column,
ybin_column,
bin_spacing,
x_column="lon",
y_column="lat"):

def pl_add_geographic_bins(
df, xbin_column, ybin_column, bin_spacing, x_column="lon", y_column="lat"
):

cut_lon_breaks, cut_lon_labels = get_cut_lon_breaks_labels(bin_spacing)
cut_lat_breaks, cut_lat_labels = get_cut_lat_breaks_labels(bin_spacing)
cut_lat_breaks, cut_lat_labels = get_cut_lat_breaks_labels(bin_spacing)
df = df.with_columns(
pl.col(x_column).cut(cut_lon_breaks, labels=cut_lon_labels).alias(xbin_column),
pl.col(y_column).cut(cut_lat_breaks, labels=cut_lat_labels).alias(ybin_column),
)
# df.filter(pl.col(xbin_column) == "outside_left")
# df.filter(pl.col(xbin_column) == "outside_right")
return df
return df


def pl_df_to_xarray(df, xbin_column, ybin_column, bin_spacing):
df_stats_pd = df.to_pandas()
df_stats_pd.dtypes

df_stats_pd[xbin_column] = df_stats_pd[xbin_column].astype(float)
df_stats_pd[ybin_column] = df_stats_pd[ybin_column].astype(float)
df_stats_pd = df_stats_pd.set_index([xbin_column, ybin_column])

## Left join to the spatial bin template dataframe
# Create the MultiIndex from the combination of x and y bins
lon_labels = get_lon_labels(bin_spacing)
lat_labels = get_lat_labels(bin_spacing)
multi_index = pd.MultiIndex.from_product(
[lon_labels, lat_labels], names=[xbin_column, ybin_column]
)

# Create an empty DataFrame with the MultiIndex
empty_df = pd.DataFrame(index=multi_index)

empty_df
# Create final dataframe

# Create final dataframe
df_stats_pd = empty_df.join(df_stats_pd, how="left")

# Reshape to xarray
# Reshape to xarray
ds = df_stats_pd.to_xarray()
ds = ds.rename({xbin_column: "longitude", ybin_column: "latitude"})
return ds

return ds
2 changes: 0 additions & 2 deletions gpm_api/bucket/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,5 +445,3 @@ def merge_granule_buckets(
# **writer_kwargs,
# )
print("Parquet Dataset writing has completed")


0 comments on commit 2180439

Please sign in to comment.