Skip to content

Commit

Permalink
Update pre-commit hooks to use latest versions of ruff and mypy (#236)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
andersy005 and pre-commit-ci[bot] authored Sep 11, 2024
1 parent bbe3499 commit aded14b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 38 deletions.
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.5.5"
rev: "v0.6.4"
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -26,13 +26,12 @@ repos:
- id: prettier

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.0
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies: [
# Type stubs
types-setuptools,
types-pkg_resources,
# Dependencies that are typed
numpy,
xarray,
Expand Down
71 changes: 36 additions & 35 deletions doc/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"outputs": [],
"source": [
"import xarray as xr\n",
"\n",
"import xbatcher"
]
},
Expand All @@ -46,12 +47,12 @@
"metadata": {},
"outputs": [],
"source": [
"store = \"az://carbonplan-share/example_cmip6_data.zarr\"\n",
"store = 'az://carbonplan-share/example_cmip6_data.zarr'\n",
"ds = xr.open_dataset(\n",
" store,\n",
" engine=\"zarr\",\n",
" engine='zarr',\n",
" chunks={},\n",
" backend_kwargs={\"storage_options\": {\"account_name\": \"carbonplan\"}},\n",
" backend_kwargs={'storage_options': {'account_name': 'carbonplan'}},\n",
")\n",
"\n",
"# the attributes contain a lot of useful information, but clutter the print out when we inspect the outputs\n",
Expand Down Expand Up @@ -98,10 +99,10 @@
"\n",
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds,\n",
" input_dims={\"time\": n_timepoint_in_each_sample},\n",
" input_dims={'time': n_timepoint_in_each_sample},\n",
")\n",
"\n",
"print(f\"{len(bgen)} batches\")"
"print(f'{len(bgen)} batches')"
]
},
{
Expand Down Expand Up @@ -133,7 +134,7 @@
"outputs": [],
"source": [
"expected_n_batch = len(ds.time) / n_timepoint_in_each_sample\n",
"print(f\"Expecting {expected_n_batch} batches, getting {len(bgen)} batches\")"
"print(f'Expecting {expected_n_batch} batches, getting {len(bgen)} batches')"
]
},
{
Expand All @@ -153,7 +154,7 @@
"source": [
"expected_batch_size = len(ds.lat) * len(ds.lon)\n",
"print(\n",
" f\"Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch\"\n",
" f'Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch'\n",
")"
]
},
Expand All @@ -179,12 +180,12 @@
"\n",
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds,\n",
" input_dims={\"time\": n_timepoint_in_each_sample},\n",
" batch_dims={\"time\": n_timepoint_in_each_batch},\n",
" input_dims={'time': n_timepoint_in_each_sample},\n",
" batch_dims={'time': n_timepoint_in_each_batch},\n",
" concat_input_dims=True,\n",
")\n",
"\n",
"print(f\"{len(bgen)} batches\")"
"print(f'{len(bgen)} batches')"
]
},
{
Expand Down Expand Up @@ -217,11 +218,11 @@
"source": [
"n_timepoint_in_batch = 31\n",
"\n",
"bgen = xbatcher.BatchGenerator(ds=ds, input_dims={\"time\": n_timepoint_in_batch})\n",
"bgen = xbatcher.BatchGenerator(ds=ds, input_dims={'time': n_timepoint_in_batch})\n",
"\n",
"for batch in bgen:\n",
" print(f\"last time point in ds is {ds.time[-1].values}\")\n",
" print(f\"last time point in batch is {batch.time[-1].values}\")\n",
" print(f'last time point in ds is {ds.time[-1].values}')\n",
" print(f'last time point in batch is {batch.time[-1].values}')\n",
"batch"
]
},
Expand Down Expand Up @@ -249,15 +250,15 @@
"\n",
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds,\n",
" input_dims={\"time\": n_timepoint_in_each_sample},\n",
" batch_dims={\"time\": n_timepoint_in_each_batch},\n",
" input_dims={'time': n_timepoint_in_each_sample},\n",
" batch_dims={'time': n_timepoint_in_each_batch},\n",
" concat_input_dims=True,\n",
" input_overlap={\"time\": input_overlap},\n",
" input_overlap={'time': input_overlap},\n",
")\n",
"\n",
"batch = bgen[0]\n",
"\n",
"print(f\"{len(bgen)} batches\")\n",
"print(f'{len(bgen)} batches')\n",
"batch"
]
},
Expand All @@ -283,10 +284,10 @@
"display(pixel)\n",
"\n",
"print(\n",
" f\"sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}\"\n",
" f'sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}'\n",
")\n",
"print(\n",
" f\"sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}\"\n",
" f'sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}'\n",
")"
]
},
Expand All @@ -310,28 +311,28 @@
"outputs": [],
"source": [
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds[[\"tasmax\"]].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
" input_dims={\"lat\": 9, \"lon\": 9, \"time\": 10},\n",
" batch_dims={\"lat\": 18, \"lon\": 18, \"time\": 15},\n",
" ds=ds[['tasmax']].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
" input_dims={'lat': 9, 'lon': 9, 'time': 10},\n",
" batch_dims={'lat': 18, 'lon': 18, 'time': 15},\n",
" concat_input_dims=True,\n",
" input_overlap={\"lat\": 8, \"lon\": 8, \"time\": 9},\n",
" input_overlap={'lat': 8, 'lon': 8, 'time': 9},\n",
")\n",
"\n",
"for i, batch in enumerate(bgen):\n",
" print(f\"batch {i}\")\n",
" print(f'batch {i}')\n",
" # make sure the ordering of dimension is consistent\n",
" batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
" batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
"\n",
" # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
" features = batch.tasmax.isel(time_input=slice(0, 9))\n",
" # select the center pixel at the last time point to be the label to be predicted\n",
" # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
" labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
"\n",
" print(\"feature shape\", features.shape)\n",
" print(\"label shape\", labels.shape)\n",
" print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape)\n",
" print(\"\")"
" print('feature shape', features.shape)\n",
" print('label shape', labels.shape)\n",
" print('shape of lat of each sample', labels.coords['lat'].shape)\n",
" print('')"
]
},
{
Expand All @@ -350,21 +351,21 @@
"outputs": [],
"source": [
"for i, batch in enumerate(bgen):\n",
" print(f\"batch {i}\")\n",
" print(f'batch {i}')\n",
" # make sure the ordering of dimension is consistent\n",
" batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
" batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
"\n",
" # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
" features = batch.tasmax.isel(time_input=slice(0, 9))\n",
" features = features.stack(features=[\"lat_input\", \"lon_input\", \"time_input\"])\n",
" features = features.stack(features=['lat_input', 'lon_input', 'time_input'])\n",
"\n",
" # select the center pixel at the last time point to be the label to be predicted\n",
" # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
" labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
"\n",
" print(\"feature shape\", features.shape)\n",
" print(\"label shape\", labels.shape)\n",
" print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape, \"\\n\")"
" print('feature shape', features.shape)\n",
" print('label shape', labels.shape)\n",
" print('shape of lat of each sample', labels.coords['lat'].shape, '\\n')"
]
},
{
Expand Down

0 comments on commit aded14b

Please sign in to comment.