diff --git a/README.md b/README.md index 7137b48..67f2bf5 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,12 @@ -# Database +# Example +Two examples with sample data are wrapped up including + - [train a LSTM network to learn SMAP soil moisture](example/train-lstm.py) + - [estimate uncertainty of a LSTM network ](example/train-lstm-mca.py) + +A demo for temporal test is [here](example/demo-temporal-test.ipynb) + + +# Database description ## Database Structure ``` ├── CONUS @@ -14,7 +22,7 @@ │   │   ├── [Constant-Variable-Name].csv │   │   └── ... │   └── crd.csv -├── CONUSv4f1wSite +├── CONUSv4f1 │   └── ... ├── Statistics │   ├── [Variable-Name]_stat.csv @@ -23,14 +31,14 @@ │   └── ... ├── Subset │   ├── CONUS.csv -│   └── CONUSv4f1wSite.csv +│   └── CONUSv4f1.csv └── Variable ├── varConstLst.csv └── varLst.csv ``` -### 1. Dataset folders (*CONUS* , *CONUSv4f1wSite*) +### 1. Dataset folders (*CONUS* , *CONUSv4f1*) Data folder contains all data including both training and testing, time-dependent variables and constant variables. -In example data structure, there are two dataset folders - *CONUS* and *CONUSv4f1wSite*. Those data are saved in: +In example data structure, there are two dataset folders - *CONUS* and *CONUSv4f1*. Those data are saved in: - **year/[Variable-Name].csv**: @@ -81,20 +89,3 @@ If the index is -1 means all grid, from example CONUS dataset. Stored csv files contains a list of variables. Used as input to training code. Time-dependent variables and constant variables should be stored seperately. For example: - varLst.csv -> a list of time-dependent variables used as training predictors. - varLst.csv -> a list of constant variables used as training predictors. - -## Code to load dataset -initilize a dataset object of CONUS dataset, from 2015 to 2016, SMAP as target and variables inside variable list files as predictor. -``` python -rootDB = [path to database] -dataset = classDB.Dataset( - rootDB=rootDB, subsetName='CONUS', - yrLst=np.arange(2015,2017), - var=('varLst', 'varConstLst'), targetName='SMAP_AM') -``` -Read data and load predictor and target to x and y. -``` python -dataset.readInput(loadNorm=True) -dataset.readTarget(loadNorm=True) -x = dataset.normInput -y = dataset.normTarget -``` \ No newline at end of file diff --git a/example/.ipynb_checkpoints/demo-temporal-test-checkpoint.ipynb b/example/.ipynb_checkpoints/demo-temporal-test-checkpoint.ipynb new file mode 100644 index 0000000..4149fb1 --- /dev/null +++ b/example/.ipynb_checkpoints/demo-temporal-test-checkpoint.ipynb @@ -0,0 +1,940 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# demo\n", + "This is a demo for model temporal test and plot the result map and time series. Before this we trained a model using [train-lstm.py](train-lstm.py). By default the model will be saved in [here](example/output/CONUSv4f1/)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Load packages and target SMAP observation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading package hydroDL\n", + "/home/kxf227/work/GitHUB/pyRnnSMAP/example/data/Subset/CONUSv4f1.csv\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SMAP_AM.csv 0.047271728515625\n" + ] + } + ], + "source": [ + "import os\n", + "from hydroDL.data import dbCsv\n", + "from hydroDL.post import plot, stat\n", + "from hydroDL import master\n", + "\n", + "cDir = os.getcwd()\n", + "rootDB = os.path.join(cDir, 'data')\n", + "tRange = [20160401, 20170401]\n", + "df = dbCsv.DataframeCsv(\n", + " rootDB=rootDB, subset='CONUSv4f1', tRange=tRange)\n", + "yt = df.getData(varT='SMAP_AM', doNorm=False, rmNan=False)\n", + "yt = yt.squeeze()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. Test the model in another year" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/kxf227/work/GitHUB/pyRnnSMAP/example/data/Subset/CONUSv4f1.csv\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/APCP_FORA.csv 0.0474090576171875\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DLWRF_FORA.csv 0.05213618278503418\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DSWRF_FORA.csv 0.051689863204956055\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/TMP_2_FORA.csv 0.05163002014160156\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SPFH_2_FORA.csv 0.05494546890258789\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/VGRD_10_FORA.csv 0.05271744728088379\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/UGRD_10_FORA.csv 0.053144216537475586\n" + ] + } + ], + "source": [ + "out = os.path.join(cDir, 'output', 'CONUSv4f1')\n", + "yp = master.test(\n", + " out, tRange=tRange, subset='CONUSv4f1')\n", + "yp = yp.squeeze()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "3. Calculate statistic metrices and plot the result. An interactive map will be generated, where users can click on map to show time series of observation and model predictions. " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "/* Put everything inside the global mpl namespace */\n", + "window.mpl = {};\n", + "\n", + "\n", + "mpl.get_websocket_type = function() {\n", + " if (typeof(WebSocket) !== 'undefined') {\n", + " return WebSocket;\n", + " } else if (typeof(MozWebSocket) !== 'undefined') {\n", + " return MozWebSocket;\n", + " } else {\n", + " alert('Your browser does not have WebSocket support.' +\n", + " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", + " 'Firefox 4 and 5 are also supported but you ' +\n", + " 'have to enable WebSockets in about:config.');\n", + " };\n", + "}\n", + "\n", + "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", + " this.id = figure_id;\n", + "\n", + " this.ws = websocket;\n", + "\n", + " this.supports_binary = (this.ws.binaryType != undefined);\n", + "\n", + " if (!this.supports_binary) {\n", + " var warnings = document.getElementById(\"mpl-warnings\");\n", + " if (warnings) {\n", + " warnings.style.display = 'block';\n", + " warnings.textContent = (\n", + " \"This browser does not support binary websocket messages. \" +\n", + " \"Performance may be slow.\");\n", + " }\n", + " }\n", + "\n", + " this.imageObj = new Image();\n", + "\n", + " this.context = undefined;\n", + " this.message = undefined;\n", + " this.canvas = undefined;\n", + " this.rubberband_canvas = undefined;\n", + " this.rubberband_context = undefined;\n", + " this.format_dropdown = undefined;\n", + "\n", + " this.image_mode = 'full';\n", + "\n", + " this.root = $('
');\n", + " this._root_extra_style(this.root)\n", + " this.root.attr('style', 'display: inline-block');\n", + "\n", + " $(parent_element).append(this.root);\n", + "\n", + " this._init_header(this);\n", + " this._init_canvas(this);\n", + " this._init_toolbar(this);\n", + "\n", + " var fig = this;\n", + "\n", + " this.waiting = false;\n", + "\n", + " this.ws.onopen = function () {\n", + " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", + " fig.send_message(\"send_image_mode\", {});\n", + " if (mpl.ratio != 1) {\n", + " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", + " }\n", + " fig.send_message(\"refresh\", {});\n", + " }\n", + "\n", + " this.imageObj.onload = function() {\n", + " if (fig.image_mode == 'full') {\n", + " // Full images could contain transparency (where diff images\n", + " // almost always do), so we need to clear the canvas so that\n", + " // there is no ghosting.\n", + " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", + " }\n", + " fig.context.drawImage(fig.imageObj, 0, 0);\n", + " };\n", + "\n", + " this.imageObj.onunload = function() {\n", + " fig.ws.close();\n", + " }\n", + "\n", + " this.ws.onmessage = this._make_on_message_function(this);\n", + "\n", + " this.ondownload = ondownload;\n", + "}\n", + "\n", + "mpl.figure.prototype._init_header = function() {\n", + " var titlebar = $(\n", + " '
');\n", + " var titletext = $(\n", + " '
');\n", + " titlebar.append(titletext)\n", + " this.root.append(titlebar);\n", + " this.header = titletext[0];\n", + "}\n", + "\n", + "\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", + "\n", + "}\n", + "\n", + "\n", + "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", + "\n", + "}\n", + "\n", + "mpl.figure.prototype._init_canvas = function() {\n", + " var fig = this;\n", + "\n", + " var canvas_div = $('
');\n", + "\n", + " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", + "\n", + " function canvas_keyboard_event(event) {\n", + " return fig.key_event(event, event['data']);\n", + " }\n", + "\n", + " canvas_div.keydown('key_press', canvas_keyboard_event);\n", + " canvas_div.keyup('key_release', canvas_keyboard_event);\n", + " this.canvas_div = canvas_div\n", + " this._canvas_extra_style(canvas_div)\n", + " this.root.append(canvas_div);\n", + "\n", + " var canvas = $('');\n", + " canvas.addClass('mpl-canvas');\n", + " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", + "\n", + " this.canvas = canvas[0];\n", + " this.context = canvas[0].getContext(\"2d\");\n", + "\n", + " var backingStore = this.context.backingStorePixelRatio ||\n", + "\tthis.context.webkitBackingStorePixelRatio ||\n", + "\tthis.context.mozBackingStorePixelRatio ||\n", + "\tthis.context.msBackingStorePixelRatio ||\n", + "\tthis.context.oBackingStorePixelRatio ||\n", + "\tthis.context.backingStorePixelRatio || 1;\n", + "\n", + " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", + "\n", + " var rubberband = $('');\n", + " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", + "\n", + " var pass_mouse_events = true;\n", + "\n", + " canvas_div.resizable({\n", + " start: function(event, ui) {\n", + " pass_mouse_events = false;\n", + " },\n", + " resize: function(event, ui) {\n", + " fig.request_resize(ui.size.width, ui.size.height);\n", + " },\n", + " stop: function(event, ui) {\n", + " pass_mouse_events = true;\n", + " fig.request_resize(ui.size.width, ui.size.height);\n", + " },\n", + " });\n", + "\n", + " function mouse_event_fn(event) {\n", + " if (pass_mouse_events)\n", + " return fig.mouse_event(event, event['data']);\n", + " }\n", + "\n", + " rubberband.mousedown('button_press', mouse_event_fn);\n", + " rubberband.mouseup('button_release', mouse_event_fn);\n", + " // Throttle sequential mouse events to 1 every 20ms.\n", + " rubberband.mousemove('motion_notify', mouse_event_fn);\n", + "\n", + " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", + " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", + "\n", + " canvas_div.on(\"wheel\", function (event) {\n", + " event = event.originalEvent;\n", + " event['data'] = 'scroll'\n", + " if (event.deltaY < 0) {\n", + " event.step = 1;\n", + " } else {\n", + " event.step = -1;\n", + " }\n", + " mouse_event_fn(event);\n", + " });\n", + "\n", + " canvas_div.append(canvas);\n", + " canvas_div.append(rubberband);\n", + "\n", + " this.rubberband = rubberband;\n", + " this.rubberband_canvas = rubberband[0];\n", + " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", + " this.rubberband_context.strokeStyle = \"#000000\";\n", + "\n", + " this._resize_canvas = function(width, height) {\n", + " // Keep the size of the canvas, canvas container, and rubber band\n", + " // canvas in synch.\n", + " canvas_div.css('width', width)\n", + " canvas_div.css('height', height)\n", + "\n", + " canvas.attr('width', width * mpl.ratio);\n", + " canvas.attr('height', height * mpl.ratio);\n", + " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", + "\n", + " rubberband.attr('width', width);\n", + " rubberband.attr('height', height);\n", + " }\n", + "\n", + " // Set the figure to an initial 600x600px, this will subsequently be updated\n", + " // upon first draw.\n", + " this._resize_canvas(600, 600);\n", + "\n", + " // Disable right mouse context menu.\n", + " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", + " return false;\n", + " });\n", + "\n", + " function set_focus () {\n", + " canvas.focus();\n", + " canvas_div.focus();\n", + " }\n", + "\n", + " window.setTimeout(set_focus, 100);\n", + "}\n", + "\n", + "mpl.figure.prototype._init_toolbar = function() {\n", + " var fig = this;\n", + "\n", + " var nav_element = $('
')\n", + " nav_element.attr('style', 'width: 100%');\n", + " this.root.append(nav_element);\n", + "\n", + " // Define a callback function for later on.\n", + " function toolbar_event(event) {\n", + " return fig.toolbar_button_onclick(event['data']);\n", + " }\n", + " function toolbar_mouse_event(event) {\n", + " return fig.toolbar_button_onmouseover(event['data']);\n", + " }\n", + "\n", + " for(var toolbar_ind in mpl.toolbar_items) {\n", + " var name = mpl.toolbar_items[toolbar_ind][0];\n", + " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", + " var image = mpl.toolbar_items[toolbar_ind][2];\n", + " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", + "\n", + " if (!name) {\n", + " // put a spacer in here.\n", + " continue;\n", + " }\n", + " var button = $('');\n", + " button.click(method_name, toolbar_event);\n", + " button.mouseover(tooltip, toolbar_mouse_event);\n", + " nav_element.append(button);\n", + " }\n", + "\n", + " // Add the status bar.\n", + " var status_bar = $('');\n", + " nav_element.append(status_bar);\n", + " this.message = status_bar[0];\n", + "\n", + " // Add the close button to the window.\n", + " var buttongrp = $('
');\n", + " var button = $('');\n", + " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", + " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", + " buttongrp.append(button);\n", + " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", + " titlebar.prepend(buttongrp);\n", + "}\n", + "\n", + "mpl.figure.prototype._root_extra_style = function(el){\n", + " var fig = this\n", + " el.on(\"remove\", function(){\n", + "\tfig.close_ws(fig, {});\n", + " });\n", + "}\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function(el){\n", + " // this is important to make the div 'focusable\n", + " el.attr('tabindex', 0)\n", + " // reach out to IPython and tell the keyboard manager to turn it's self\n", + " // off when our div gets focus\n", + "\n", + " // location in version 3\n", + " if (IPython.notebook.keyboard_manager) {\n", + " IPython.notebook.keyboard_manager.register_events(el);\n", + " }\n", + " else {\n", + " // location in version 2\n", + " IPython.keyboard_manager.register_events(el);\n", + " }\n", + "\n", + "}\n", + "\n", + "mpl.figure.prototype._key_event_extra = function(event, name) {\n", + " var manager = IPython.notebook.keyboard_manager;\n", + " if (!manager)\n", + " manager = IPython.keyboard_manager;\n", + "\n", + " // Check for shift+enter\n", + " if (event.shiftKey && event.which == 13) {\n", + " this.canvas_div.blur();\n", + " event.shiftKey = false;\n", + " // Send a \"J\" for go to next cell\n", + " event.which = 74;\n", + " event.keyCode = 74;\n", + " manager.command_mode();\n", + " manager.handle_keydown(event);\n", + " }\n", + "}\n", + "\n", + "mpl.figure.prototype.handle_save = function(fig, msg) {\n", + " fig.ondownload(fig, null);\n", + "}\n", + "\n", + "\n", + "mpl.find_output_cell = function(html_output) {\n", + " // Return the cell and output element which can be found *uniquely* in the notebook.\n", + " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", + " // IPython event is triggered only after the cells have been serialised, which for\n", + " // our purposes (turning an active figure into a static one), is too late.\n", + " var cells = IPython.notebook.get_cells();\n", + " var ncells = cells.length;\n", + " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", + " data = data.data;\n", + " }\n", + " if (data['text/html'] == html_output) {\n", + " return [cell, data, j];\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "\n", + "// Register the function which deals with the matplotlib target/channel.\n", + "// The kernel may be null if the page has been refreshed.\n", + "if (IPython.notebook.kernel != null) {\n", + " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", + "}\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# calculate stat\n", + "statErr = stat.statError(yp, yt)\n", + "dataGrid = [statErr['RMSE'], statErr['Corr']]\n", + "dataTs = [yp, yt]\n", + "t = df.getT()\n", + "crd = df.getGeo()\n", + "mapNameLst = ['RMSE', 'Correlation']\n", + "tsNameLst = ['LSTM', 'SMAP']\n", + "colorMap = None\n", + "colorTs = None\n", + "# plot map and time series\n", + "%matplotlib notebook\n", + "plot.plotTsMap(\n", + " dataGrid,\n", + " dataTs,\n", + " crd,\n", + " t,\n", + " colorMap=colorMap,\n", + " mapNameLst=mapNameLst,\n", + " tsNameLst=tsNameLst,\n", + " figsize=[8,4])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/example/train-lstm-mca.py b/example/train-lstm-mca.py index cf51b70..d9185a0 100644 --- a/example/train-lstm-mca.py +++ b/example/train-lstm-mca.py @@ -9,7 +9,7 @@ master.default.optDataCsv, path=os.path.join(cDir, 'data'), subset='CONUSv4f1', - tRange=[20150401, 20160331], + tRange=[20150401, 20160401], ) optModel = master.default.optLstm optLoss = master.updateOpt( @@ -23,4 +23,4 @@ # test pred = master.test( - out, tRange=[20160401, 20170331], subset='CONUSv4f1', epoch=500) + out, tRange=[20160401, 20170401], subset='CONUSv4f1', epoch=500) diff --git a/hydroDL/post/plot.py b/hydroDL/post/plot.py index 6a66a44..d6af3b4 100644 --- a/hydroDL/post/plot.py +++ b/hydroDL/post/plot.py @@ -190,7 +190,8 @@ def plotTsMap(dataGrid, *, colorMap=None, mapNameLst=None, - tsNameLst=None): + tsNameLst=None, + figsize=[12, 6]): if type(dataGrid) is np.ndarray: dataGrid = [dataGrid] if type(dataTs) is np.ndarray: @@ -198,7 +199,7 @@ def plotTsMap(dataGrid, nMap = len(dataGrid) nTs = len(dataTs) - fig = plt.figure(figsize=[12, 6]) + fig = plt.figure(figsize=figsize) gs = gridspec.GridSpec(3, nMap) for k in range(nMap):