From 0696f463718af9da07ab0806face6189987e7b29 Mon Sep 17 00:00:00 2001 From: DeepMind Team Date: Thu, 5 Dec 2024 05:59:29 -0800 Subject: [PATCH] Add GPU option to cifar10_tensorflow example. PiperOrigin-RevId: 703085899 Change-Id: Icaf5c0e524a41c461744d5824ca388c75658cdd6 GitOrigin-RevId: 06593deec2145f64ddd1323c44c8f70fedcc2ef5 --- examples/cifar10_tensorflow/launcher.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/cifar10_tensorflow/launcher.py b/examples/cifar10_tensorflow/launcher.py index f0a7922..f6e7249 100644 --- a/examples/cifar10_tensorflow/launcher.py +++ b/examples/cifar10_tensorflow/launcher.py @@ -30,6 +30,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string('tensorboard', None, 'Tensorboard instance.') +flags.DEFINE_integer('gpus_per_node', 2, 'Number of GPUs per node.') def main(_): @@ -76,7 +77,10 @@ def main(_): experiment.add( xm.Job( executable=executable, - executor=xm_local.Vertex(tensorboard=tensorboard_capability), + executor=xm_local.Vertex( + tensorboard=tensorboard_capability, + requirements=xm.JobRequirements(t4=FLAGS.gpus_per_node), + ), args=hyperparameters, ) )