
Commit: small edits

Paul Yang authored and committed Oct 16, 2024
1 parent be8a3c4 commit 210910e
Showing 1 changed file with 8 additions and 4 deletions.

examples/torch-training/TorchBasicExample.py
@@ -219,7 +219,11 @@ def return_status(self):
 # Define a cluster type - here we launch an on-demand AWS cluster with 1 NVIDIA A10G GPU.
 # You can use any cloud you want, or existing compute
 cluster = rh.ondemand_cluster(
-    name="a10g-cluster", instance_type="A10G:1", provider="aws"
+    name="a10g-cluster",
+    instance_type="A10G:1",
+    provider="aws",
+    autostop_mins=120,
+    # name="a10g-cluster", instance_type="T4:1", provider="gcp", autostop_mins=120  # If we wanted to use GCP, for example
 ).up_if_not()
 
 # Next, we define the environment for our module. This includes the required dependencies that need
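The env definition and the rest of this passage are collapsed in the diff view. As a rough sketch of what the collapsed setup typically looks like (not part of this commit: `rh.env`, the package list, and the static-cluster values below are all assumptions), the dependencies could be declared, and an existing machine could stand in for the on-demand cluster, like so:

import runhouse as rh

# Assumed env definition - the actual collapsed code in the file may differ.
env = rh.env(reqs=["torch", "torchvision"], name="pytorch-env")

# Sketch of the "existing compute" alternative mentioned above; every value here is a placeholder.
byo_cluster = rh.cluster(
    name="byo-cluster",
    ips=["<ip-of-your-machine>"],
    ssh_creds={"ssh_user": "ubuntu", "ssh_private_key": "~/.ssh/id_rsa"},
)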
@@ -244,9 +248,6 @@ def return_status(self):
     name="torch_model"
 ) # Instantiating it based on the remote RH module, and naming it "torch_model".
 
-# Though we could just as easily run identical code on local
-# model = SimpleTrainer() # If instantiating a local example
-
 # We set some settings for the model training
 batch_size = 64
 epochs = 5
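The lines that send `SimpleTrainer` to the cluster are truncated above this hunk. A minimal sketch of the usual Runhouse pattern (an assumption based on the surrounding code, not taken from the commit; `RemoteTrainer` is a hypothetical name) would be:

RemoteTrainer = rh.module(SimpleTrainer).to(cluster, env=env)  # send the class to the cluster's env
model = RemoteTrainer(name="torch_model")  # instantiate it remotely, named "torch_model"

# A hypothetical use of the settings above; the real training loop is collapsed in this diff,
# and train_epoch is an assumed method name, not one from the commit.
for epoch in range(epochs):
    model.train_epoch(batch_size=batch_size)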
@@ -286,3 +287,6 @@ def return_status(self):
 example_data, example_target = local_dataset[0][0].unsqueeze(0), local_dataset[0][1]
 prediction = model.predict(example_data)
 print(f"Predicted: {prediction}, Actual: {example_target}")
+
+# Bring down the cluster when done. If you have saved the cluster to Runhouse Den, you can also reuse it by name for another task with `cluster = rh.cluster('a10g-cluster')`.
+cluster.teardown()
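Per the new comment, a cluster saved to Den can be reattached by name from another session. A small sketch (assuming the cluster was saved beforehand; `cluster.save()` is not shown in this commit):

cluster.save()  # persist the cluster config to Runhouse Den (assumed step, not in this diff)
reused = rh.cluster("a10g-cluster").up_if_not()  # later, elsewhere: reattach by name and restart if needed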
