making train steps depend on arg
kingb12 committed Sep 9, 2023
1 parent: a2ba682, commit: 1e4f6c4
Showing 3 changed files with 7 additions and 3 deletions.
k8s/peft/starcoder_custom_train_example.yml (2 changes: 1 addition & 1 deletion)
@@ -58,7 +58,7 @@ spec:
python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('${HF_API_TOKEN}')" &&
echo "job ready to start" &&
export TOKENIZERS_PARALLELISM=true &&
- conda run --no-capture-output -p ./venv python src/hf_libraries_demo/experiments/peft/custom_training_loop_example.py --batch_size 2 &&
+ conda run --no-capture-output -p ./venv python src/hf_libraries_demo/experiments/peft/custom_training_loop_example.py --batch_size 4 &&
echo "job complete!"
# some arguments needed by kubernetes, plus some useful defaults
volumeMounts:
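The command above runs the custom training loop script with --batch_size 4; the two Python files below add the --max_train_steps flag that the commit message refers to, so the same command could also pass, for example, --max_train_steps 64. Below is a minimal sketch, not the repository's exact code, of how such flags are typically consumed with argparse; the --batch_size default and the example argv are illustrative assumptions.

    import argparse

    # Minimal sketch of the CLI surface used by the k8s command above.
    # The --batch_size default is illustrative; --max_train_steps mirrors the diffs below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=2,
                        help='per-device training batch size (illustrative default)')
    parser.add_argument('--max_train_steps', type=int, default=32,
                        help='number of training steps to take')

    # Example argv, matching what the k8s job could pass (illustrative values):
    args = parser.parse_args(['--batch_size', '4', '--max_train_steps', '64'])
    print(args.batch_size, args.max_train_steps)  # -> 4 64

As a rough guide, at --batch_size 4 and the default 32 steps the training loop would see about 4 x 32 = 128 examples.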
src/hf_libraries_demo/experiments/peft/custom_training_loop_example.py (4 changes: 3 additions & 1 deletion)
@@ -178,6 +178,8 @@ def train(self):
parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading (default: 8)')
parser.add_argument('--pin_memory', action='store_true', default=True,
help='Use pinned (page-locked) memory. If not set, defaults to True.')
+ parser.add_argument('--max_train_steps', type=int, default=32,
+                     help='number of training steps to take')
args = parser.parse_args()

# Load and process the dataset. Added more training data points to get a more complete test.
@@ -249,7 +251,7 @@ def train(self):

# ============== Start of code changes: implementing our own training loop ========================================

- num_training_steps: int = 32
+ num_training_steps: int = args.max_train_steps

# Setup optimizer and scheduler
optimizer = get_optimizer(model, weight_decay=0.05, optimizer_class=AdamW)
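With this change, the number of optimizer steps taken by the custom training loop comes from args.max_train_steps instead of a hardcoded 32. Below is a minimal sketch of a step-bounded loop driven by such an argument; the names run_training, train_dataloader, and lr_scheduler are illustrative assumptions rather than the repository's exact code, and the sketch assumes a Hugging Face-style model whose forward pass returns a loss.

    import itertools

    def run_training(model, train_dataloader, optimizer, lr_scheduler, max_train_steps: int):
        """Run exactly max_train_steps optimizer steps (illustrative sketch)."""
        model.train()
        completed_steps = 0
        # Cycle the dataloader so a small dataset can still fill the step budget.
        for batch in itertools.cycle(train_dataloader):
            if completed_steps >= max_train_steps:
                break
            outputs = model(**batch)   # forward pass; the model is assumed to return a loss
            outputs.loss.backward()    # backward pass
            optimizer.step()           # parameter update
            lr_scheduler.step()        # advance the learning-rate schedule
            optimizer.zero_grad()
            completed_steps += 1
        return model

Called as run_training(model, train_dataloader, optimizer, lr_scheduler, args.max_train_steps), the loop stops after the requested number of steps regardless of dataset size.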
Third changed file (4 changes: 3 additions & 1 deletion)
@@ -22,6 +22,8 @@
parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading (default: 8)')
parser.add_argument('--pin_memory', action='store_true', default=True,
help='Use pinned (page-locked) memory. If not set, defaults to True.')
+ parser.add_argument('--max_train_steps', type=int, default=32,
+                     help='number of training steps to take')
args = parser.parse_args()

# Load and process the dataset. Added more training data points to get a more complete test.
@@ -99,7 +101,7 @@
evaluation_strategy="steps",
save_strategy="steps",
load_best_model_at_end=True,
- max_steps=32,
+ max_steps=args.max_train_steps,
eval_steps=16,
save_steps=16,
logging_steps=1,
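This third file forwards the same flag to Hugging Face's Trainer through TrainingArguments.max_steps. Below is a minimal sketch of that wiring, assuming the transformers library; only the keyword arguments visible in the diff are reproduced, while output_dir and the stand-in args namespace are illustrative assumptions.

    import argparse

    from transformers import TrainingArguments

    # Stand-in for the script's parsed arguments (illustrative).
    args = argparse.Namespace(max_train_steps=64)

    training_args = TrainingArguments(
        output_dir="./outputs",            # illustrative; not shown in the diff
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        max_steps=args.max_train_steps,    # was hardcoded to 32 before this commit
        eval_steps=16,
        save_steps=16,
        logging_steps=1,
    )

With max_steps set, the Trainer stops after that many optimizer steps even if the dataset would otherwise yield more batches.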
