Skip to content

Commit

Permalink
fix notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
juannat7 committed Aug 27, 2024
1 parent 62d08ab commit adeda6a
Showing 1 changed file with 63 additions and 17 deletions.
80 changes: 63 additions & 17 deletions notebooks/02a_s2s_modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "6dbd3180-4d25-46b7-90fc-f8db79661301",
"metadata": {},
"outputs": [],
Expand All @@ -29,10 +29,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "d8dc2f0a-5b6d-4382-87a8-fe2aecd1b786",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/burg/home/jn2808/.conda/envs/bench/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
Expand Down Expand Up @@ -72,10 +81,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "562a1d04-95f7-4544-ac2a-d4f1bd092941",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/burg/home/jn2808/.conda/envs/bench/lib/python3.9/site-packages/gribapi/__init__.py:23: UserWarning: ecCodes 2.31.0 or higher is recommended. You are running version 2.30.0\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Specify train/val years + test benchmark\n",
"train_years = np.arange(2016, 2022)\n",
Expand Down Expand Up @@ -117,7 +135,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "9b6144b9-ee33-449b-82ed-371ba1bc61a1",
"metadata": {},
"outputs": [],
Expand All @@ -133,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "5977a244-5138-418c-b4be-886fadfb6230",
"metadata": {},
"outputs": [],
Expand All @@ -147,18 +165,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "1bc1bfee-e65a-4f20-b4ac-28067d3fdf92",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train/val x: torch.Size([4, 60, 121, 240])\n",
"train/val y: torch.Size([4, 1, 60, 121, 240])\n"
]
}
],
"source": [
"print(f'train/val x: {train_x.shape}') # Each tensor has the shape of (batch_size, params, lat, lon)\n",
"print(f'train/val y: {train_y.shape}') # Each tensor has the shape of (batch_size, step_size, params, lat, lon)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "1c4657f3-4bb1-4a13-8595-5f8a1d8b5054",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -188,22 +215,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "b1e23cbe-cf32-407a-95a0-3536298f6391",
"metadata": {},
"outputs": [],
"source": [
"# Specify model specifications\n",
"\n",
"model = cnn.UNet(input_size=train_x.shape[1], output_size=train_x.shape[1], dropout=True, dropout_rate=0.1)\n"
"model = cnn.UNet(input_size=train_x.shape[1], output_size=train_x.shape[1])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "036344fb-badd-49ef-9edc-22444c43a1e1",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 60, 121, 240])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the model to get output\n",
"preds = model(train_x)\n",
Expand All @@ -221,7 +259,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "9f5fac83-d23e-47c2-9919-e7804640b3f8",
"metadata": {},
"outputs": [],
Expand All @@ -232,10 +270,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "99308b7a-424b-4c9a-a76e-a002c68f6ed8",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0959651470184326\n"
]
}
],
"source": [
"# Compute error\n",
"preds = model(train_x)\n",
Expand Down

0 comments on commit adeda6a

Please sign in to comment.