Skip to content

Commit

Permalink
updated workflow and addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mitanshudodia committed Jan 2, 2025
1 parent 0aaef53 commit 99278e7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
28 changes: 21 additions & 7 deletions workflows/train_mnist_wf/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
## Change the following variable value before deploying the workflow on your workspace
## There are few pre-requisite we have fulfill before deploying the workflow

- Default value of `workspace_fqn` in main workflow function.
- Value of env variables `TFY_API_KEY` `TFY_HOST` in `task_config`.
- you can use virtual accout token as `TFY_API_KEY`, click [here](https://docs.truefoundry.com/docs/generating-truefoundry-api-keys#virtual-accounts) to learn about how to create virtual account.
- `host` value in `deploy.py`
### Creating the ml repo and giving the workspace the access to that ml repo
- First create a ml repo where you want to log the models. To learn about how to create a ml repo, click [here](https://docs.truefoundry.com/docs/creating-a-ml-repo#/).
- Give ml repo access to the workspace where you will be deploying your workflow and the model. To know about how to give access click [here](https://docs.truefoundry.com/docs/key-concepts#/grant-access-of-ml-repo-to-workspace)

### Setting the value of default variables

- Set the value of env variables `TFY_API_KEY` `TFY_HOST` in `task_config` in `train-deploy-workflow.py` file.
- you can use virtual account token as `TFY_API_KEY`, click [here](https://docs.truefoundry.com/docs/generating-truefoundry-api-keys#virtual-accounts) to learn about how to create a virtual account.
- `host` value in `Port` field in `deploy.py` file

## Deploying the workflow

You can deploy the workflow using the following command, make sure your truefoudry cli version is more thatn `4.0.0`.

```bash
tfy deploy workflow --name <wf-name> --file <file-name> --workspace-fqn <workspace-fqn>
```
tfy deploy workflow --name <wf-name> --file train-deploy-workflow.py --workspace-fqn <workspace-fqn>
```
**Make sure you have workflow helm chart installed in the workspace in which you are deploying workflow**

## Executing the workflow
The workflow takes following arguments as input while executing the workflow.
`ml_repo`: The name of the ml repo where you want to deploy the model. The workspace should have access to this ml repo.
`workspace_fqn`: Workspace fqn where you want to deploy the model.
`epochs`: An array of integer which define the number of epoch you want to train the model for, each epoch will run with corresponding learning rate which you will give in `learning_rate` argument. The lenght of `epochs` and `learning_rate` shoud be same.
`learning_rate`: An array of float where each number is the learning rate you want your model to train with, corresponding to the epochs defined at same postion.
`accuracy_threshold`: The threshold value, so the workflow will deploy the model if its validation accuracy is greater than this threshold accuracy.
9 changes: 6 additions & 3 deletions workflows/train_mnist_wf/train-deploy-workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
resources=Resources(cpu_request=1.2, cpu_limit=1.2, memory_limit=3000, memory_request=3000, ephemeral_storage_limit=2000, ephemeral_storage_request=2000),
service_account="default",
env={
"TF_CPP_MIN_LOG_LEVEL": "3", # suppress tensorflow warnings
"FLYTE_SDK_LOGGING_LEVEL": "40",
"TFY_API_KEY": "<your-api-key>",
"TFY_HOST": "<tfy-host-value>",
}
Expand Down Expand Up @@ -56,6 +58,7 @@ def train_model(epochs: int, learning_rate: float, data: Dict[str, np.array], ml
model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

epochs = epochs
print(f"Started training the model for {epochs} epochs")
history = model.fit(x_train, y_train, epochs=epochs, validation_data=(x_test, y_test))

# Evaluate the model
Expand Down Expand Up @@ -122,11 +125,11 @@ def model_not_found(threshold: float) -> str:


@workflow
def model_training_workflow(ml_repo: str, epochs: List[int] = [2, 3, 5], learning_rate: List[float] = [0.1, 0.001, 0.001], accuracy_threshold: float = 0.15, workspace_fqn: str = "<your-ws-fqn>") -> Union[str, None]:
def model_training_workflow(ml_repo: str, workspace_fqn: str, epochs: List[int] = [2, 3, 5], learning_rate: List[float] = [0.1, 0.001, 0.001], accuracy_threshold: float = 0.15) -> Union[str, None]:
data = fetch_data()
train_model_function = partial(train_model, data=data, ml_repo=ml_repo)
fqns = map_task(train_model_function, concurrency=2)(epochs=epochs, learning_rate=learning_rate)
best_fqn, is_best_model_found = get_run_fqn_of_best_model(fqns=fqns, threshold=accuracy_threshold)
message = conditional("Deploy best model").if_(is_best_model_found == True).then(deploy_model(run_fqn=best_fqn, workspace_fqn=workspace_fqn)).else_().then(model_not_found(threshold=accuracy_threshold))
model_version_fqn, does_model_pass_threshold_accuracy = get_run_fqn_of_best_model(fqns=fqns, threshold=accuracy_threshold)
message = conditional("Deploy model").if_(does_model_pass_threshold_accuracy == True).then(deploy_model(run_fqn=model_version_fqn, workspace_fqn=workspace_fqn)).else_().then(model_not_found(threshold=accuracy_threshold))

return message

0 comments on commit 99278e7

Please sign in to comment.