diff --git a/README.md b/README.md index fafb677..635c061 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,15 @@ It is 3D global output from the mid-top CAM model, on the original model grid. However, the demo data here is one very small part of the CAM output due to storage limit of Github. NN trained on this Demodata will not work. +# Installing + +Clone this repo and enter it.\ +Then run: +``` +pip install . +``` +to install the neccessary dependencies.\ +It is recommended this is done from inside a virtual environment. # data loader load 3D CAM data and reshaping them to the NN input. diff --git a/Model.py b/newCAM_emulation/Model.py similarity index 100% rename from Model.py rename to newCAM_emulation/Model.py diff --git a/NN_pred.py b/newCAM_emulation/NN_pred.py similarity index 100% rename from NN_pred.py rename to newCAM_emulation/NN_pred.py diff --git a/newCAM_emulation/__init__.py b/newCAM_emulation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loaddata.py b/newCAM_emulation/loaddata.py similarity index 100% rename from loaddata.py rename to newCAM_emulation/loaddata.py diff --git a/train.py b/newCAM_emulation/train.py similarity index 96% rename from train.py rename to newCAM_emulation/train.py index a8c1314..7379560 100644 --- a/train.py +++ b/newCAM_emulation/train.py @@ -40,8 +40,8 @@ def early_stop(self, validation_loss): ## load mean and std for normalization -fm = np.load('Demodata/mean_demo_sub.npz') -fs = np.load('Demodata/std_demo_sub.npz') +fm = np.load('../Demodata/mean_demo_sub.npz') +fs = np.load('../Demodata/std_demo_sub.npz') Um = fm['U'] Vm = fm['V'] @@ -92,7 +92,7 @@ def early_stop(self, validation_loss): if (iter > 1): model.load_state_dict(torch.load('conv_torch.pth')) print ('data loader iteration',iter) - filename = './Demodata/newCAM_demo_sub_' + str(iter).zfill(1) + '.nc' + filename = '../Demodata/newCAM_demo_sub_' + str(iter).zfill(1) + '.nc' print('working on: ', filename) F = nc.Dataset(filename) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..572ca9a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,61 @@ +[build-system] +requires = ["setuptools >= 61"] +build-backend = "setuptools.build_meta" + +[project] +name = "newCAM_emulation" +version = "0.0.0" +description = "PyTorch Net to emulate the gravity wave drag in CAM" +authors = [ + { name="Qiang Sun", email="qiangsun@uchicago.edu" }, +] +readme = "README.md" +license = {file = "LICENSE"} +requires-python = ">=3.9" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Natural Language :: English", + "Programming Language :: Python :: 3", + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Typing :: Typed', + "Operating System :: OS Independent", +] + +dependencies = [ + "numpy>=1.20.0", + "torch", + "torchvision", + "scipy", + "matplotlib", + "xarray", + "netcdf4", +] + +[project.optional-dependencies] +lint = [ + "black>=24.1.0", + "pylint", + # "mypy>=1.0.0", + # "pytest>=7.2.0", + # "pytest-mock", + "pydocstyle", +] + +[project.urls] +"Homepage" = "https://github.com/DataWaveProject/newCAM_emulation" +"Bug Tracker" = "https://github.com/DataWaveProject/newCAM_emulation/issues" + +[tool.setuptools] +# By default, include-package-data is true in pyproject.toml, so you do +# NOT have to specify this line. +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] # list of folders that contain the packages (["."] by default) +include = ["newCAM_emulation*"] # package names should match these glob patterns (["*"] by default) +exclude = ["Demodata/*"] # exclude packages matching these glob patterns (empty by default) +namespaces = false # to disable scanning PEP 420 namespaces (true by default)