-
Notifications
You must be signed in to change notification settings - Fork 24
/
mnist_utils.py
30 lines (26 loc) · 939 Bytes
/
mnist_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Databricks notebook source
import matplotlib.pyplot as plt
import numpy as np
# Helper function to display an image
def imshow(img):
    """Display an image with matplotlib.

    The input is rescaled with ``img / 2 + 0.5``, converted through its
    ``.numpy()`` method, and its axes are reordered to ``(1, 2, 0)`` before
    being handed to ``plt.imshow``; ``plt.show()`` is then called.

    Args:
        img: An object supporting ``/`` and ``+`` with scalars whose result
            exposes ``.numpy()``.
            NOTE(review): presumably a channels-first torch tensor that was
            normalized with mean 0.5 / std 0.5 -- confirm against the caller.
    """
    # Undo the normalization (assumed mean=0.5, std=0.5 -- TODO confirm).
    rescaled = img / 2 + 0.5
    pixels = rescaled.numpy()
    # Move the leading axis to the end so matplotlib receives the axes in
    # (1, 2, 0) order.
    reordered = np.transpose(pixels, axes=(1, 2, 0))
    plt.imshow(reordered)
    plt.show()
# COMMAND ----------
def log_mnist_model(model, run_id=None):
    """Log ``model`` through ``mlflow.pytorch.log_model`` with a pinned conda env.

    The conda environment is built with ``_mlflow_conda_env`` and pins
    cloudpickle, torch and torchvision to the versions currently loaded in
    this process, plus pillow at 6.0.0. The model is logged under
    ``MODEL_SAVE_PATH``.

    NOTE(review): ``cloudpickle``, ``torch``, ``torchvision``, ``mlflow`` and
    ``MODEL_SAVE_PATH`` are not defined in the visible part of this file;
    presumably they are provided by the notebook that includes this one --
    confirm.

    Args:
        model: The model object passed straight to ``mlflow.pytorch.log_model``.
        run_id: Optional run identifier. When truthy, logging happens inside
            ``mlflow.start_run(run_id=run_id)``; otherwise ``log_model`` is
            called directly without opening a run here.
    """
    # Build a conda env description that lists the required frameworks.
    from mlflow.utils.environment import _mlflow_conda_env

    pinned_pip_deps = [
        f"cloudpickle=={cloudpickle.__version__}",
        f"torch=={torch.__version__}",
        f"torchvision=={torchvision.__version__}",
        "pillow==6.0.0",
    ]
    conda_spec = _mlflow_conda_env(additional_pip_deps=pinned_pip_deps)

    # No explicit run requested: log without opening a run in this function.
    if not run_id:
        mlflow.pytorch.log_model(model, MODEL_SAVE_PATH, conda_env=conda_spec)
        return

    # Attach the logged model to the requested run.
    with mlflow.start_run(run_id=run_id):
        mlflow.pytorch.log_model(model, MODEL_SAVE_PATH, conda_env=conda_spec)