diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 96592ca477..45ddc6b63e 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -217,6 +217,7 @@ def main(args: Namespace) -> None: if device is not None: print(f'Placing model on {device=}...') model.to(device) + model.to(model_dtype) except Exception as e: raise RuntimeError( 'Unable to load HF model. ' +