From 35827e5bcd979e4e26b49e079938f0cc471d4d3a Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Wed, 2 Aug 2023 20:48:59 -0500 Subject: [PATCH 01/12] :sparkles: Add fine-tuning support in text generation local module Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../text_generation/text_generation_local.py | 253 +++++++++++++++++- .../test_text_generation_local.py | 89 +++++- 2 files changed, 329 insertions(+), 13 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index ba35f8dd..7cd33813 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -16,12 +16,15 @@ # Standard import gc import os +from typing import Optional # Third Party -from transformers import AutoConfig +from torch.utils.data import IterableDataset +from transformers import AutoConfig, AutoTokenizer import torch # First Party +from caikit.core.data_model import DataStream from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.toolkit import error_handler from caikit.interfaces.nlp.data_model import GeneratedTextResult @@ -29,11 +32,14 @@ import alog # Local +from ...data_model import GenerationTrainRecord from ...resources.pretrained_model import ( HFAutoCausalLM, HFAutoSeq2SeqLM, PretrainedModelBase, ) +from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper +from ...toolkit.data_type_utils import get_torch_dtype log = alog.use_channel("TXT_GEN") error = error_handler.get(log) @@ -49,25 +55,32 @@ class TextGeneration(ModuleBase): """Module to provide text generation capabilities""" + RANDOM_SEED = 73 supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM] def __init__( self, - base_model_name: str, - base_model: PretrainedModelBase = None, - eos_token: str = None, + model_name: str, + model: PretrainedModelBase = None, + bos_token: Optional[str] = None, + sep_token: Optional[str] = None, + eos_token: Optional[str] = None, + pad_token: Optional[str] = None, ): super().__init__() error.type_check("", str, allow_none=True, eos_token=eos_token) - self.base_model = base_model - self.base_model_name = base_model_name + self.model = model + self.model_name = model_name + self._bos_token = bos_token + self._sep_token = sep_token self._eos_token = eos_token + self._pad_token = pad_token # pylint: disable=duplicate-code def __del__(self): - del self.base_model + del self.model gc.collect() try: torch.cuda.empty_cache() @@ -114,6 +127,193 @@ def bootstrap(cls, base_model_path: str): eos_token=eos_token, ) + + @classmethod + def train( + cls, + base_model: str, # TODO: Union[str, PretrainedModelBase] + train_stream: DataStream[GenerationTrainRecord], + torch_dtype: str = None, # TODO: Optional[Union[torch.dtype, str]] + max_source_length: int = 256, + max_target_length: int = 128, + batch_size: int = 8, + num_epochs: int = 5, + accumulate_steps: int = 32, + random_seed: int = RANDOM_SEED, + lr: float = 2e-5, + # Directory where model predictions and checkpoints will be written + checkpoint_dir: str = "/tmp", + **training_arguments, + ): + """ + Fine-tune a CausalLM or Seq2seq text generation model. + + Args: + base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase] + Base resource model used for underlying generation. + train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord] + Data to be used for fine-tuning the generation model. 
+ torch_dtype: str + TODO: Optional[Union[torch.dtype, str]] + Data type to use for training/inference of the underlying text generation model. + If no value is provided, we pull from torch_dtype in config. If an in memory + resource is provided which does not match the specified data type, the model + underpinning the resource will be converted in place to the correct torch dtype. + max_source_length: int + Max length of input sequences being considered. Default: 256. + max_target_length: int + Max length of target sequences being predicted. Default: 128. + batch_size: int + Batch sized to be used for training / evaluation data. Default: 8. + num_epochs: int + Number of epochs to tune the model. Default: 20. + accumulate_steps: int + Number of steps to use for gradient accumulation. Default: 1. + lr: float + Learning rate to be used while tuning model. Default: 2e-5. + checkpoint_dir: str + Directory where model predictions and checkpoints will be written + **training_arguments: + Arguments supported by HF Training Arguments. + TrainingArguments: + https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments + Seq2SeqTrainingArguments: + https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments + Returns: + FineTuning + Instance of this class with fine-tuned models. + """ + + torch_dtype = get_torch_dtype(torch_dtype) + + ## NOTE: Below code has been used in couple of places at this point, like in + # text_generation module. In future, we would want to consolidate this into + # a base class or a toolkit function + # pylint: disable=duplicate-code + resource_type = None + + ## Load base model + if isinstance(base_model, str): + model_config = AutoConfig.from_pretrained(base_model) + + for resource in cls.supported_resources: + if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: + resource_type = resource + break + + if not resource_type: + error( + "", + "{} model type is not supported currently!".format( + model_config.model_type + ), + ) + log.debug("Bootstrapping base resource [%s]", base_model) + base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype) + + else: + # base_model is actually a resource object + resource_type = type(base_model) + + error.type_check("", PretrainedModelBase, base_model=base_model) + ## Generate data loader from stream + training_dataset: IterableDataset = cls._preprocess_function( + base_model=base_model, + train_stream=train_stream, + tokenizer=base_model.tokenizer, + max_source_length=max_source_length, + max_target_length=max_target_length, + shuffle=True, + ) + + ### Dtype based processing + # NOTE: Following is not exhaustive list of all parameters + # for all dtypes + if torch_dtype == torch.float16: + dtype_based_params = { + "fp16": True, + } + elif torch_dtype == torch.bfloat16: + dtype_based_params = { + "bf16": True, + } + else: + # default to float32 + dtype_based_params = {} + + ## TODO: Add automatic sharding selection based on number of parameters + # in base model + ## TODO: Fetch trainer from resource + + # TODO: Make this whole thing configurable by end-users, + # by optionally accepting `training_args` + # as argument to this train function. 
+ # TODO: Remove all the default used below and make them all configurable + + training_args = { + "output_dir": checkpoint_dir, + "per_device_train_batch_size": batch_size, + "per_device_eval_batch_size": batch_size, + "num_train_epochs": num_epochs, + "seed": random_seed, + # NOTE: We have disabled evaluation for now + "do_eval": False, + # "evaluation_strategy ": "epoch", + "learning_rate": lr, + "weight_decay": 0.01, + "save_total_limit": 3, + "push_to_hub": False, + "no_cuda": False, # Default + "remove_unused_columns": False, + "dataloader_pin_memory": False, + "gradient_accumulation_steps": accumulate_steps, + "eval_accumulation_steps": accumulate_steps, + # eval_steps=1, + # load_best_model_at_end + **training_arguments, + **dtype_based_params, + } + + trainer = base_model.get_trainer( + train_dataset=training_dataset, **training_args + ) + + if num_epochs < 1: + log.warning( + "", + f"Number of epochs configured is {num_epochs} which is less than minimum 1. \ + No training will be performed", + ) + + return cls( + model_name=base_model._model_name, + model=base_model, + ) + + # Start training via Trainer.train function + trainer.train() + + # save the model temporarily and reload it + # this is done, since otherwise the model might be distributed in different + # devices, in which case its better to use trainer's `prediction_step` + # functions, but then, they don't always give API similar to `generate` + # and thus cause incompatibilities in `run` function + trainer.save_model(checkpoint_dir) + + model = resource_type.bootstrap( + checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype + ) + + return cls( + model_name=base_model._model_name, + model=model, + bos_token=model.tokenizer.bos_token or None, + sep_token=model.tokenizer.sep_token or None, + eos_token=model.tokenizer.eos_token or None, + pad_token=model.tokenizer.pad_token or None, + ) + + @classmethod def load(cls, model_path: str) -> "TextGeneration": """Function to load text-generation model @@ -153,9 +353,9 @@ def save(self, model_path): "eos_token": self._eos_token, } ) - if self.base_model: + if self.model: # This will save both tokenizer and base model - self.base_model.save( + self.model.save( model_path, tokenizer_dirname=artifacts_dir, base_model_dirname=artifacts_dir, @@ -213,8 +413,8 @@ def run( Generated text result produced by the model. """ - inputs = self.base_model.tokenizer(text, return_tensors="pt") - generate_ids = self.base_model.model.generate( + inputs = self.model.tokenizer(text, return_tensors="pt") + generate_ids = self.model.model.generate( input_ids=inputs["input_ids"], num_beams=num_beams, max_new_tokens=max_new_tokens, @@ -227,7 +427,7 @@ def run( ) token_count = generate_ids.size(1) - 1 preds = [ - self.base_model.tokenizer.decode( + self.model.tokenizer.decode( g, skip_special_tokens=True, clean_up_tokenization_spaces=True ) for g in generate_ids @@ -244,3 +444,32 @@ def run( finish_reason=finish_reason, producer_id=self.PRODUCER_ID, ) + + + ################################## Private Functions ########################################### + + @staticmethod + def _preprocess_function( + base_model: PretrainedModelBase, + train_stream: DataStream[GenerationTrainRecord], + tokenizer: AutoTokenizer, + max_source_length: int, + max_target_length: int, + shuffle: bool, + ): + """Pre-process each example to get it prepared for training.""" + + # TODO: We are using a default verbalizer which is strictly tied to + # source training record currently. 
We need to figure out a better + # way to make verbalizer optional for build_task_tokenize_function + ( + tokenize_function, + requires_unwrapping, + ) = base_model.build_task_tokenize_function( + tokenizer, max_source_length, max_target_length, verbalizer="{{input}}" + ) + mapped_stream = train_stream.map(tokenize_function) + if requires_unwrapping: + mapped_stream = mapped_stream.flatten() + + return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle) diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py index b5505475..5ba514bf 100644 --- a/tests/modules/text_generation/test_text_generation_local.py +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -6,10 +6,21 @@ # First Party from caikit.interfaces.nlp.data_model import GeneratedTextResult +import caikit + +# Third Party +import torch # Local +from caikit_nlp.data_model import GenerationTrainRecord from caikit_nlp.modules.text_generation import TextGeneration -from tests.fixtures import CAUSAL_LM_MODEL, SEQ2SEQ_LM_MODEL +from caikit_nlp.resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM +from tests.fixtures import ( + CAUSAL_LM_MODEL, + SEQ2SEQ_LM_MODEL, + disable_wip, + set_cpu_device, +) ### Stub Modules @@ -55,3 +66,79 @@ def test_save_model_can_run(): sample_text = "Hello stub" generated_text = new_model.run(sample_text) assert isinstance(generated_text, GeneratedTextResult) + +############################## Training ################################ + +def test_train_model_seq2seq(disable_wip, set_cpu_device): + """Ensure that we can finetune a seq2seq model on some toy data for 1+ + steps & run inference.""" + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 1, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ), + GenerationTrainRecord( + input="@bar this is the worst idea ever.", output="complaint" + ), + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoSeq2SeqLM) + # Ensure that we can get something out of it + pred = model.run("@bar what a cute cat!") + assert isinstance(pred, GeneratedTextResult) + + +def test_train_model_causallm(disable_wip, set_cpu_device): + """Ensure that we can finetune a causal-lm model on some toy data for 1+ + steps & run inference.""" + train_kwargs = { + "base_model": HFAutoCausalLM.bootstrap( + model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL + ), + "num_epochs": 1, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ), + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoCausalLM) + + # Ensure that we can get something out of it + pred = model.run("@bar what a cute cat!") + assert isinstance(pred, GeneratedTextResult) + + +############################## Error Cases ################################ + + +def test_zero_epoch_case(disable_wip): + """Test to ensure 0 epoch training request doesn't explode""" + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 0, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + 
[ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ), + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoSeq2SeqLM) From 2ea49ec0d75dc82213585f8661091072f4a081ef Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Wed, 2 Aug 2023 20:51:44 -0500 Subject: [PATCH 02/12] :white_check_mark: Add test for save and load of fine-tuned model Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../test_text_generation_local.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py index 5ba514bf..68f4926b 100644 --- a/tests/modules/text_generation/test_text_generation_local.py +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -91,11 +91,36 @@ def test_train_model_seq2seq(disable_wip, set_cpu_device): } model = TextGeneration.train(**train_kwargs) assert isinstance(model.model, HFAutoSeq2SeqLM) + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + new_model = TextGeneration.load(model_dir) + sample_text = "Hello stub" + generated_text = new_model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + + +def test_train_model_save_and_load(disable_wip, set_cpu_device): + """Ensure that we are able to save and load a finetuned model and execute inference on it""" + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 1, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ) + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoSeq2SeqLM) # Ensure that we can get something out of it pred = model.run("@bar what a cute cat!") assert isinstance(pred, GeneratedTextResult) - def test_train_model_causallm(disable_wip, set_cpu_device): """Ensure that we can finetune a causal-lm model on some toy data for 1+ steps & run inference.""" From ec3c04bc45f6f57749e5f3b98bf14b1a3b82620d Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Wed, 2 Aug 2023 20:54:52 -0500 Subject: [PATCH 03/12] :art: Fix linting and fine-tuning test statements Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../text_generation/text_generation_local.py | 9 ++---- .../test_text_generation_local.py | 28 +++++++++++-------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 7cd33813..5d35cdac 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -14,9 +14,9 @@ # Standard +from typing import Optional import gc import os -from typing import Optional # Third Party from torch.utils.data import IterableDataset @@ -127,7 +127,6 @@ def bootstrap(cls, base_model_path: str): eos_token=eos_token, ) - @classmethod def train( cls, @@ -313,7 +312,6 @@ def train( pad_token=model.tokenizer.pad_token or None, ) - @classmethod def load(cls, model_path: str) -> "TextGeneration": """Function to load text-generation model @@ -370,7 +368,7 @@ def run( num_beams=1, max_new_tokens=20, min_new_tokens=0, - **kwargs + **kwargs, ) -> "GeneratedTextResult": """Run 
inference against the model running in TGIS. @@ -445,8 +443,7 @@ def run( producer_id=self.PRODUCER_ID, ) - - ################################## Private Functions ########################################### + ################################## Private Functions ###################################### @staticmethod def _preprocess_function( diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py index 68f4926b..55b9b7e3 100644 --- a/tests/modules/text_generation/test_text_generation_local.py +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -4,13 +4,13 @@ import os import tempfile +# Third Party +import torch + # First Party from caikit.interfaces.nlp.data_model import GeneratedTextResult import caikit -# Third Party -import torch - # Local from caikit_nlp.data_model import GenerationTrainRecord from caikit_nlp.modules.text_generation import TextGeneration @@ -67,8 +67,10 @@ def test_save_model_can_run(): generated_text = new_model.run(sample_text) assert isinstance(generated_text, GeneratedTextResult) + ############################## Training ################################ + def test_train_model_seq2seq(disable_wip, set_cpu_device): """Ensure that we can finetune a seq2seq model on some toy data for 1+ steps & run inference.""" @@ -91,12 +93,10 @@ def test_train_model_seq2seq(disable_wip, set_cpu_device): } model = TextGeneration.train(**train_kwargs) assert isinstance(model.model, HFAutoSeq2SeqLM) - with tempfile.TemporaryDirectory() as model_dir: - model.save(model_dir) - new_model = TextGeneration.load(model_dir) - sample_text = "Hello stub" - generated_text = new_model.run(sample_text) - assert isinstance(generated_text, GeneratedTextResult) + + # Ensure that we can get something out of it + pred = model.run("@bar what a cute cat!") + assert isinstance(pred, GeneratedTextResult) def test_train_model_save_and_load(disable_wip, set_cpu_device): @@ -117,9 +117,13 @@ def test_train_model_save_and_load(disable_wip, set_cpu_device): } model = TextGeneration.train(**train_kwargs) assert isinstance(model.model, HFAutoSeq2SeqLM) - # Ensure that we can get something out of it - pred = model.run("@bar what a cute cat!") - assert isinstance(pred, GeneratedTextResult) + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + new_model = TextGeneration.load(model_dir) + sample_text = "Hello stub" + generated_text = new_model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + def test_train_model_causallm(disable_wip, set_cpu_device): """Ensure that we can finetune a causal-lm model on some toy data for 1+ From 3d2ee1609c6fbf732bf0739d0ca37441a61f4f23 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Wed, 2 Aug 2023 20:57:55 -0500 Subject: [PATCH 04/12] :fire: Remove separate fine-tuning module Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../modules/text_generation/__init__.py | 1 - .../modules/text_generation/fine_tuning.py | 342 ------------------ .../text_generation/test_fine_tuning.py | 94 ----- 3 files changed, 437 deletions(-) delete mode 100644 caikit_nlp/modules/text_generation/fine_tuning.py delete mode 100644 tests/modules/text_generation/test_fine_tuning.py diff --git a/caikit_nlp/modules/text_generation/__init__.py b/caikit_nlp/modules/text_generation/__init__.py index 078178f7..8c696d81 100644 --- a/caikit_nlp/modules/text_generation/__init__.py +++ b/caikit_nlp/modules/text_generation/__init__.py @@ -13,7 +13,6 @@ # limitations under the 
License. # Local -from .fine_tuning import FineTuning from .peft_prompt_tuning import PeftPromptTuning from .peft_tgis_remote import PeftPromptTuningTGIS from .text_generation_local import TextGeneration diff --git a/caikit_nlp/modules/text_generation/fine_tuning.py b/caikit_nlp/modules/text_generation/fine_tuning.py deleted file mode 100644 index f3933516..00000000 --- a/caikit_nlp/modules/text_generation/fine_tuning.py +++ /dev/null @@ -1,342 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Standard -from typing import Optional - -# Third Party -from torch.utils.data import IterableDataset -from transformers import AutoConfig, AutoTokenizer -import torch - -# First Party -from caikit.core.data_model import DataStream -from caikit.core.modules import ModuleBase, module -from caikit.core.toolkit import error_handler, wip_decorator -from caikit.interfaces.nlp.data_model import GeneratedTextResult -from caikit.interfaces.nlp.tasks import TextGenerationTask -import alog - -# Local -from ...data_model import GenerationTrainRecord -from ...resources.pretrained_model import ( - HFAutoCausalLM, - HFAutoSeq2SeqLM, - PretrainedModelBase, -) -from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper -from ...toolkit.data_type_utils import get_torch_dtype - -log = alog.use_channel("FIN_TUN_GEN") -error = error_handler.get(log) - - -# pylint: disable=too-many-lines,too-many-instance-attributes -@module( - id="28a81449-32ce-4be3-b688-545bde68f738", - name="Text Generation", - version="0.1.0", - task=TextGenerationTask, -) -@wip_decorator.work_in_progress( - category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.ERROR -) -class FineTuning(ModuleBase): - """Module to provide fine-tuning support for text generation task""" - - RANDOM_SEED = 73 - supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM] - - def __init__( - self, - tokenizer, - model, - bos_token: Optional[str] = None, - sep_token: Optional[str] = None, - eos_token: Optional[str] = None, - pad_token: Optional[str] = None, - ): - super().__init__() - - self.tokenizer = tokenizer - self.model = model - self._bos_token = bos_token - self._sep_token = sep_token - self._eos_token = eos_token - self._pad_token = pad_token - - @classmethod - def train( - cls, - base_model: str, # TODO: Union[str, PretrainedModelBase] - train_stream: DataStream[GenerationTrainRecord], - torch_dtype: str = None, # TODO: Optional[Union[torch.dtype, str]] - max_source_length: int = 256, - max_target_length: int = 128, - batch_size: int = 8, - num_epochs: int = 5, - accumulate_steps: int = 32, - random_seed: int = RANDOM_SEED, - lr: float = 2e-5, - # Directory where model predictions and checkpoints will be written - checkpoint_dir: str = "/tmp", - **training_arguments, - ): - """ - Fine-tune a CausalLM or Seq2seq text generation model. - - Args: - base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase] - Base resource model used for underlying generation. 
- train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord] - Data to be used for fine-tuning the generation model. - torch_dtype: str - TODO: Optional[Union[torch.dtype, str]] - Data type to use for training/inference of the underlying text generation model. - If no value is provided, we pull from torch_dtype in config. If an in memory - resource is provided which does not match the specified data type, the model - underpinning the resource will be converted in place to the correct torch dtype. - max_source_length: int - Max length of input sequences being considered. Default: 256. - max_target_length: int - Max length of target sequences being predicted. Default: 128. - batch_size: int - Batch sized to be used for training / evaluation data. Default: 8. - num_epochs: int - Number of epochs to tune the model. Default: 20. - accumulate_steps: int - Number of steps to use for gradient accumulation. Default: 1. - lr: float - Learning rate to be used while tuning model. Default: 2e-5. - checkpoint_dir: str - Directory where model predictions and checkpoints will be written - **training_arguments: - Arguments supported by HF Training Arguments. - TrainingArguments: - https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments - Seq2SeqTrainingArguments: - https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments - Returns: - FineTuning - Instance of this class with fine-tuned models. - """ - - torch_dtype = get_torch_dtype(torch_dtype) - - ## NOTE: Below code has been used in couple of places at this point, like in - # text_generation module. In future, we would want to consolidate this into - # a base class or a toolkit function - # pylint: disable=duplicate-code - resource_type = None - - ## Load base model - if isinstance(base_model, str): - model_config = AutoConfig.from_pretrained(base_model) - - for resource in cls.supported_resources: - if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: - resource_type = resource - break - - if not resource_type: - error( - "", - "{} model type is not supported currently!".format( - model_config.model_type - ), - ) - log.debug("Bootstrapping base resource [%s]", base_model) - base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype) - - else: - # base_model is actually a resource object - resource_type = type(base_model) - - error.type_check("", PretrainedModelBase, base_model=base_model) - ## Generate data loader from stream - training_dataset: IterableDataset = cls._preprocess_function( - base_model=base_model, - train_stream=train_stream, - tokenizer=base_model.tokenizer, - max_source_length=max_source_length, - max_target_length=max_target_length, - shuffle=True, - ) - - ### Dtype based processing - # NOTE: Following is not exhaustive list of all parameters - # for all dtypes - if torch_dtype == torch.float16: - dtype_based_params = { - "fp16": True, - } - elif torch_dtype == torch.bfloat16: - dtype_based_params = { - "bf16": True, - } - else: - # default to float32 - dtype_based_params = {} - - ## TODO: Add automatic sharding selection based on number of parameters - # in base model - ## TODO: Fetch trainer from resource - - # TODO: Make this whole thing configurable by end-users, - # by optionally accepting `training_args` - # as argument to this train function. 
- # TODO: Remove all the default used below and make them all configurable - - training_args = { - "output_dir": checkpoint_dir, - "per_device_train_batch_size": batch_size, - "per_device_eval_batch_size": batch_size, - "num_train_epochs": num_epochs, - "seed": random_seed, - # NOTE: We have disabled evaluation for now - "do_eval": False, - # "evaluation_strategy ": "epoch", - "learning_rate": lr, - "weight_decay": 0.01, - "save_total_limit": 3, - "push_to_hub": False, - "no_cuda": False, # Default - "remove_unused_columns": False, - "dataloader_pin_memory": False, - "gradient_accumulation_steps": accumulate_steps, - "eval_accumulation_steps": accumulate_steps, - # eval_steps=1, - # load_best_model_at_end - **training_arguments, - **dtype_based_params, - } - - trainer = base_model.get_trainer( - train_dataset=training_dataset, **training_args - ) - - if num_epochs < 1: - log.warning( - "", - f"Number of epochs configured is {num_epochs} which is less than minimum 1. \ - No training will be performed", - ) - - return cls( - tokenizer=base_model.tokenizer, - model=trainer, - ) - - # Start training via Trainer.train function - trainer.train() - - # save the model temporarily and reload it - # this is done, since otherwise the model might be distributed in different - # devices, in which case its better to use trainer's `prediction_step` - # functions, but then, they don't always give API similar to `generate` - # and thus cause incompatibilities in `run` function - trainer.save_model(checkpoint_dir) - - model = resource_type.bootstrap( - checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype - ) - - return cls( - tokenizer=model.tokenizer, - model=model, - bos_token=model.tokenizer.bos_token or None, - sep_token=model.tokenizer.sep_token or None, - eos_token=model.tokenizer.eos_token or None, - pad_token=model.tokenizer.pad_token or None, - ) - - # pylint: disable=unused-argument - def run( - self, text, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 - ) -> "GeneratedTextResult": - """Run inference against the model running in TGIS. - - Args: - text: str - Source string to be encoded for generation. - preserve_input_text: bool - Whether or not the source string should be contained in the generated output, - e.g., as a prefix. - max_new_tokens: int - The maximum numbers of tokens to generate. - Default: 128 - min_new_tokens: int - The minimum numbers of tokens to generate. 
- Default: 0 - means no minimum - Returns: - GeneratedTextResult - Generated text result - """ - - inputs = self.model.tokenizer(text, return_tensors="pt") - generate_ids = self.model.model.generate( - input_ids=inputs["input_ids"], - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - use_cache=True, - ) - - token_count = generate_ids.size(1) - 1 - preds = [ - self.model.tokenizer.decode( - g, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - for g in generate_ids - ] - if generate_ids[0][-1].item() == self._eos_token: - finish_reason = "EOS_TOKEN" - elif generate_ids.size(1) - 1 == max_new_tokens: - finish_reason = "MAX_TOKENS" - else: - finish_reason = "OTHER" - - return GeneratedTextResult( - generated_tokens=token_count, - generated_text=preds[0], - finish_reason=finish_reason, - producer_id=self.PRODUCER_ID, - ) - - ################################## Private Functions ########################################### - - @staticmethod - def _preprocess_function( - base_model: PretrainedModelBase, - train_stream: DataStream[GenerationTrainRecord], - tokenizer: AutoTokenizer, - max_source_length: int, - max_target_length: int, - shuffle: bool, - ): - """Pre-process each example to get it prepared for training.""" - - # TODO: We are using a default verbalizer which is strictly tied to - # source training record currently. We need to figure out a better - # way to make verbalizer optional for build_task_tokenize_function - ( - tokenize_function, - requires_unwrapping, - ) = base_model.build_task_tokenize_function( - tokenizer, max_source_length, max_target_length, verbalizer="{{input}}" - ) - mapped_stream = train_stream.map(tokenize_function) - if requires_unwrapping: - mapped_stream = mapped_stream.flatten() - - return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle) diff --git a/tests/modules/text_generation/test_fine_tuning.py b/tests/modules/text_generation/test_fine_tuning.py deleted file mode 100644 index a17f5ffa..00000000 --- a/tests/modules/text_generation/test_fine_tuning.py +++ /dev/null @@ -1,94 +0,0 @@ -# Third Party -from transformers import Trainer -import pytest -import torch - -# First Party -from caikit.interfaces.nlp.data_model import GeneratedTextResult -import caikit - -# Local -from caikit_nlp.data_model import GenerationTrainRecord -from caikit_nlp.modules.text_generation import FineTuning -from caikit_nlp.resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM -from tests.fixtures import ( - CAUSAL_LM_MODEL, - SEQ2SEQ_LM_MODEL, - disable_wip, - set_cpu_device, -) - - -def test_train_model_seq2seq(disable_wip, set_cpu_device): - """Ensure that we can finetune a seq2seq model on some toy data for 1+ - steps & run inference.""" - train_kwargs = { - "base_model": HFAutoSeq2SeqLM.bootstrap( - model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL - ), - "num_epochs": 1, - "train_stream": caikit.core.data_model.DataStream.from_iterable( - [ - GenerationTrainRecord( - input="@foo what a cute dog!", output="no complaint" - ), - GenerationTrainRecord( - input="@bar this is the worst idea ever.", output="complaint" - ), - ] - ), - "torch_dtype": torch.float32, - } - model = FineTuning.train(**train_kwargs) - assert isinstance(model.model, HFAutoSeq2SeqLM) - # Ensure that we can get something out of it - pred = model.run("@bar what a cute cat!") - assert isinstance(pred, GeneratedTextResult) - - -def test_train_model_causallm(disable_wip, set_cpu_device): - """Ensure that we can finetune a causal-lm model on some toy data 
for 1+ - steps & run inference.""" - train_kwargs = { - "base_model": HFAutoCausalLM.bootstrap( - model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL - ), - "num_epochs": 1, - "train_stream": caikit.core.data_model.DataStream.from_iterable( - [ - GenerationTrainRecord( - input="@foo what a cute dog!", output="no complaint" - ), - ] - ), - "torch_dtype": torch.float32, - } - model = FineTuning.train(**train_kwargs) - assert isinstance(model.model, HFAutoCausalLM) - - # Ensure that we can get something out of it - pred = model.run("@bar what a cute cat!") - assert isinstance(pred, GeneratedTextResult) - - -############################## Error Cases ################################ - - -def test_zero_epoch_case(disable_wip): - """Test to ensure 0 epoch training request doesn't explode""" - train_kwargs = { - "base_model": HFAutoSeq2SeqLM.bootstrap( - model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL - ), - "num_epochs": 0, - "train_stream": caikit.core.data_model.DataStream.from_iterable( - [ - GenerationTrainRecord( - input="@foo what a cute dog!", output="no complaint" - ), - ] - ), - "torch_dtype": torch.float32, - } - model = FineTuning.train(**train_kwargs) - assert isinstance(model.model, Trainer) From aeab02476de7b5b43f7a269fed33c5360b36f864 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Thu, 3 Aug 2023 17:46:18 -0500 Subject: [PATCH 05/12] :recycle: Update text generation tgis and its tests to work with remote TGIS concept Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../text_generation/text_generation_tgis.py | 124 ++++++++---------- .../test_text_generation_tgis.py | 4 +- 2 files changed, 55 insertions(+), 73 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 74f64c20..ea1a567d 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -14,8 +14,7 @@ # Standard -from typing import Iterable, Optional -import os +from typing import Iterable, Optional, Union # First Party from caikit.core.module_backends import BackendBase, backend_types @@ -52,8 +51,8 @@ class TextGenerationTGIS(ModuleBase): def __init__( self, - base_model_name: str, - base_model: Optional[PretrainedModelBase] = None, + model_name: str, + model: Optional[PretrainedModelBase] = None, bos_token: Optional[str] = None, sep_token: Optional[str] = None, eos_token: Optional[str] = None, @@ -66,8 +65,8 @@ def __init__( error.type_check("", str, allow_none=True, sep_token=sep_token) error.type_check("", str, allow_none=True, eos_token=eos_token) error.type_check("", str, allow_none=True, pad_token=pad_token) - self.base_model = base_model - self.base_model_name = base_model_name + self.model = model + self.model_name = model_name # Set _model_loaded as False by default. This will only get set to True if # we enable the tgis_backend and we are able to fetch the client successfully. @@ -77,49 +76,36 @@ def __init__( # for example, bootstrapping a model to caikit format and saving. 
self._client = None if tgis_backend: - self._client = tgis_backend.get_client(base_model_name) + self._client = tgis_backend.get_client(model_name) # mark that the model is loaded so that we can unload it later self._model_loaded = True + self.tgis_backend = tgis_backend self._bos_token = bos_token self._sep_token = sep_token self._eos_token = eos_token self._pad_token = pad_token self.tgis_generation_client = TGISGenerationClient( - self.base_model_name, self._eos_token, self._client, self.PRODUCER_ID + self.model_name, self._eos_token, self._client, self.PRODUCER_ID ) def __del__(self): # nothing to unload if we didn't finish loading - if self._model_loaded and self.load_backend: - self.load_backend.unload_model(self._model_path) + if self._model_loaded and self.tgis_backend: + self.tgis_backend.unload_model(self.model_name) @classmethod - def bootstrap(cls, base_model_path: str, load_backend: BackendBase = None): - """Function to bootstrap a pre-trained transformers model and - get a caikit text-generation 'model'. + def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None): - Args: - base_model_path: str - Path to transformers model - NOTE: Model path needs to contain tokenizer as well - load_backend: BackendBase - Backend object to be used to run inference with. - NOTE: this is required for inferencing. It is - made optional to support the model conversion use-case - Returns: - caikit_nlp.blocks.text_generation.TextGeneration - Object of TextGeneration class (model) - """ - text_generation_inst = TextGeneration.bootstrap(base_model_path) - bos_token = text_generation_inst.base_model._tokenizer.bos_token - sep_token = text_generation_inst.base_model._tokenizer.sep_token - eos_token = text_generation_inst.base_model._tokenizer.eos_token or None - pad_token = text_generation_inst.base_model._tokenizer.pad_token + text_generation_inst = TextGeneration.bootstrap(model_path) + bos_token = text_generation_inst.model._tokenizer.bos_token + sep_token = text_generation_inst.model._tokenizer.sep_token + eos_token = text_generation_inst.model._tokenizer.eos_token or None + pad_token = text_generation_inst.model._tokenizer.pad_token return cls( - text_generation_inst.base_model_name, - text_generation_inst.base_model, + text_generation_inst.model_name, + text_generation_inst.model, bos_token=bos_token, sep_token=sep_token, eos_token=eos_token, @@ -127,38 +113,12 @@ def bootstrap(cls, base_model_path: str, load_backend: BackendBase = None): tgis_backend=load_backend, ) - def save(self, model_path): - """Save caikit model - - Args: - model_path: str - Folder to save text-generation caikit model - """ - saver = ModuleSaver(self, model_path=model_path) - with saver: - saver.update_config( - { - "base_model_name": self.base_model_name, - "bos_token": self._bos_token, - "sep_token": self._sep_token, - "eos_token": self._eos_token, - "pad_token": self._pad_token, - } - ) - if self.base_model: - artifacts_dir = "artifacts" - log.debug("Saving model artifacts to %s", artifacts_dir) - saver.update_config({"artifact_path": artifacts_dir}) - # This will save both tokenizer and base model - self.base_model.save( - model_path, - tokenizer_dirname=artifacts_dir, - base_model_dirname=artifacts_dir, - ) - @classmethod def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration": - """Function to load text-generation model + """Function to load text-generation model. 
Note, this only loads + "remote" style model, i.e the cakit-model that doesn't + necessarily required to have actual artifacts in it + and thus only saves them in "remote" format. Args: model_path: str @@ -172,17 +132,12 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration": error.type_check("", TGISBackend, load_backend=load_backend) config = ModuleConfig.load(model_path) - artifacts_path = config.artifact_path - if artifacts_path: - base_model_name = os.path.join(model_path, artifacts_path) - error.dir_check("", base_model_name) - log.debug("Loading with on-disk artifacts: %s", base_model_name) - else: - base_model_name = config.base_model_name - error.type_check("", str, base_model_name=base_model_name) - log.debug("Loading with model name: %s", base_model_name) + + model_name = config.model_name + error.type_check("", str, model_name=model_name) + log.debug("Loading with model name: %s", model_name) return cls( - base_model_name, + model_name, bos_token=config.bos_token, sep_token=config.sep_token, eos_token=config.eos_token, @@ -190,6 +145,31 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration": tgis_backend=load_backend, ) + def save(self, model_path: str): + """Export the config for this model. + This saves the model in "remote" style + and does not store the actual model artifacts + along with the caikit-model. + + model_path: str + Path to which we should write our model. + """ + # pylint: disable=duplicate-code + saver = ModuleSaver( + self, + model_path=model_path, + ) + with saver: + saver.update_config( + { + "model_name": self.model_name, + "bos_token": self._bos_token, + "sep_token": self._sep_token, + "eos_token": self._eos_token, + "pad_token": self._pad_token, + } + ) + @TextGenerationTask.taskmethod() def run( self, text, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py index efd1cb8f..dc8137af 100644 --- a/tests/modules/text_generation/test_text_generation_tgis.py +++ b/tests/modules/text_generation/test_text_generation_tgis.py @@ -59,7 +59,9 @@ def test_run_multi_response_errors(): def test_bootstrap_and_save_model(): """Check if we can bootstrap and save the model successfully""" - model = TextGenerationTGIS.bootstrap(SEQ2SEQ_LM_MODEL) + model = TextGenerationTGIS.bootstrap( + SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend() + ) with tempfile.TemporaryDirectory() as model_dir: model.save(model_dir) From d2c0b9943765d9d3333c3707e97f94c5fb2fb5b3 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Thu, 3 Aug 2023 18:03:27 -0500 Subject: [PATCH 06/12] :recycle: Refactor load and save functions to provide portability Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../text_generation/text_generation_tgis.py | 14 +++++-- .../test_text_generation_tgis.py | 38 ++++++++++++++++++- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index ea1a567d..7cd87d26 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -15,6 +15,7 @@ # Standard from typing import Iterable, Optional, Union +import os # First Party from caikit.core.module_backends import BackendBase, backend_types @@ -132,10 +133,15 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration": 
error.type_check("", TGISBackend, load_backend=load_backend) config = ModuleConfig.load(model_path) - - model_name = config.model_name - error.type_check("", str, model_name=model_name) - log.debug("Loading with model name: %s", model_name) + artifacts_path = config.artifact_path + if artifacts_path: + model_name = os.path.join(model_path, artifacts_path) + error.dir_check("", model_name) + log.debug("Loading with on-disk artifacts: %s", model_name) + else: + model_name = config.model_name + error.type_check("", str, model_name=model_name) + log.debug("Loading with model name: %s", model_name) return cls( model_name, bos_token=config.bos_token, diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py index dc8137af..75360eb8 100644 --- a/tests/modules/text_generation/test_text_generation_tgis.py +++ b/tests/modules/text_generation/test_text_generation_tgis.py @@ -7,9 +7,16 @@ # Third Party import pytest +import torch + +# First Party +from caikit.interfaces.nlp.data_model import GeneratedTextResult +import caikit # Local -from caikit_nlp.modules.text_generation import TextGenerationTGIS +from caikit_nlp.data_model.generation import GenerationTrainRecord +from caikit_nlp.modules.text_generation import TextGeneration, TextGenerationTGIS +from caikit_nlp.resources.pretrained_model.hf_auto_seq2seq_lm import HFAutoSeq2SeqLM from tests.fixtures import ( CAUSAL_LM_MODEL, SEQ2SEQ_LM_MODEL, @@ -81,6 +88,35 @@ def test_save_model_can_run(): StubTGISClient.validate_unary_generate_response(result) +def test_local_train_load_tgis(): + """Check if the model trained in local module is able to + be loaded in TGIS module / backend + """ + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 1, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ) + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + new_model = TextGenerationTGIS.load( + model_dir, load_backend=StubTGISBackend(mock_remote=True) + ) + sample_text = "Hello stub" + generated_text = new_model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + + def test_remote_tgis_only_model(): """Make sure that a model can be created and used that will only work with a remote TGIS connection (i.e. 
it has no artifacts) From 805d13955344fcd00de305897ba6107ce08cc788 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Fri, 4 Aug 2023 15:16:55 -0500 Subject: [PATCH 07/12] :sparkles: Fix fine-tuning script to make it optionally work with remote TGIS for evaluation Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../text_generation/text_generation_local.py | 20 +++++++-- examples/run_fine_tuning.py | 44 ++++++++++++++++--- examples/text-generation-launcher | 7 +-- examples/utils.py | 12 +++-- 4 files changed, 66 insertions(+), 17 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 5d35cdac..c27fb2c8 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -39,7 +39,7 @@ PretrainedModelBase, ) from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper -from ...toolkit.data_type_utils import get_torch_dtype +from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype log = alog.use_channel("TXT_GEN") error = error_handler.get(log) @@ -179,7 +179,7 @@ def train( Seq2SeqTrainingArguments: https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments Returns: - FineTuning + TextGeneration Instance of this class with fine-tuned models. """ @@ -313,24 +313,36 @@ def train( ) @classmethod - def load(cls, model_path: str) -> "TextGeneration": + def load( + cls, + model_path: str, + torch_dtype: str = None, + ) -> "TextGeneration": """Function to load text-generation model Args: model_path: str Path to the model to be loaded. + torch_dtype: str + Torch data type to be used when loading the model. Returns: TextGeneration Instance of this class built from the on disk model. """ config = ModuleConfig.load(model_path) + + if torch_dtype is not None: + torch_dtype = str_to_torch_dtype(torch_dtype) + else: + torch_dtype = str_to_torch_dtype(config.trained_torch_dtype) + base_model_path = config.get("artifact_path") error.type_check("", str, base_model_path=base_model_path) base_model_path = os.path.join(model_path, base_model_path) error.dir_check("", base_model_path) - return cls.bootstrap(base_model_path) + return cls.bootstrap(base_model_path, torch_dtype) def save(self, model_path): """Save caikit model diff --git a/examples/run_fine_tuning.py b/examples/run_fine_tuning.py index 4768859e..65733497 100644 --- a/examples/run_fine_tuning.py +++ b/examples/run_fine_tuning.py @@ -20,6 +20,7 @@ SUPPORTED_METRICS, DatasetInfo, configure_random_seed_and_logging, + load_model, print_colored, ) @@ -30,7 +31,7 @@ # Local from caikit_nlp.data_model import GenerationTrainRecord, TuningConfig -from caikit_nlp.modules.text_generation import FineTuning +from caikit_nlp.modules.text_generation import TextGeneration from caikit_nlp.resources.pretrained_model import ( HFAutoCausalLM, HFAutoSeq2SeqLM, @@ -170,6 +171,11 @@ def register_common_arguments(subparser: argparse.ArgumentParser) -> None: nargs="*", default=["accuracy"], ) + subparser.add_argument( + "--tgis", + help="Run inference using TGIS. 
NOTE: This involves saving and reloading model in TGIS container", + action="store_true", + ) def validate_common_args(args: argparse.Namespace): @@ -247,7 +253,12 @@ def get_model_preds_and_references(model, validation_stream): for datum in tqdm(validation_stream): # Local .run() currently prepends the input text to the generated string; # Ensure that we're just splitting the first predicted token & beyond. - raw_model_text = model.run(datum.input).text + # TEMP HACK to avoid too big input to TGIS problem: + try: + raw_model_text = model.run(datum.input).generated_text + except Exception as ex: + print(ex) + continue parse_pred_text = raw_model_text.split(datum.input)[-1].strip() model_preds.append(parse_pred_text) targets.append(datum.output) @@ -303,20 +314,18 @@ def export_model_preds(preds_file, predictions, validation_stream): # Then actually train the model & save it print_colored("[Starting the training...]") - model = FineTuning.train( + model = TextGeneration.train( base_model, train_stream, max_source_length=args.max_source_length, max_target_length=args.max_target_length, lr=args.learning_rate, - torch_dtype="float16", + torch_dtype=args.torch_dtype, batch_size=args.batch_size, accumulate_steps=args.accumulate_steps, num_epochs=args.num_epochs, ) - # model.save(args.output_dir, save_base_model=not args.prompt_only) - print_colored("[Training Complete]") # Prediction @@ -325,13 +334,34 @@ def export_model_preds(preds_file, predictions, validation_stream): print("Generated text: ", prediction_results) + if args.tgis: + + # Saving model + model.save(args.output_dir) + + # Load model in TGIS + # HACK: export args.output_dir as MODEL_NAME for TGIS + # container to pick up automatically + os.environ["MODEL_DIR"] = os.path.dirname(args.output_dir) + os.environ[ + "MODEL_NAME" + ] = f"/models/{os.path.basename(args.output_dir)}/artifacts" + + loaded_model = load_model(is_distributed=True, model_path=args.output_dir) + + else: + # Use trained model directly + loaded_model = model + ## Evaluation print_colored("[Starting Evaluation]") validation_stream = dataset_info.dataset_loader()[1] print_colored("Getting model predictions...") - predictions, references = get_model_preds_and_references(model, validation_stream) + predictions, references = get_model_preds_and_references( + loaded_model, validation_stream + ) export_model_preds(args.preds_file, predictions, validation_stream) diff --git a/examples/text-generation-launcher b/examples/text-generation-launcher index 4204f93c..6273dcb2 100755 --- a/examples/text-generation-launcher +++ b/examples/text-generation-launcher @@ -1,16 +1,17 @@ #!/usr/bin/env bash -# This script is primarily meant for illustrative purposes; if we don't have +# This script is primarily meant for illustrative purposes; if we don't have # the text-generation-launcher command locally available, but we do have a Docker # container, we add this script onto our path so when the TGIS backend in caikit # tries to start the server, it runs this script instead. # -# NOTE: +# NOTE: # - Model ID, directories, etc are hardcoded to our example, params from the backend, # e.g., shard configuration, are ignored. # # - We need to export port 3000 (for probes in core distributed), and we forward 8033->50055 # so that our gRPC server is exposed on the expected port for local TGIS. 
TGIS_MODEL="${MODEL_NAME:-bigscience/bloom-560m}" +MODEL_DIR="${MODEL_DIR:-models}" echo "Running TGIS with model: $TGIS_MODEL" docker run --rm \ @@ -21,7 +22,7 @@ docker run --rm \ -p 8087:8087 \ -p 50055:8033 \ -p 3000:3000 \ - -v $(pwd)/models:/models \ + -v $(pwd)/${MODEL_DIR}:/models \ -v $(pwd)/../runtime_config.yaml:/conf/runtime_config.yaml \ -v $(pwd)/transformers_cache:/shared_model_storage/transformers_cache \ -v $(pwd)/prompt_prefixes:/prompt_prefixes \ diff --git a/examples/utils.py b/examples/utils.py index 80d18cf2..f3c610be 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -39,6 +39,8 @@ "formatter": "pretty", } +log = alog.use_channel("EXMPL_UTILS") + def configure_random_seed_and_logging(): """Ensure that random experiments will be deterministic & set up default ALOG configuration.""" @@ -86,7 +88,7 @@ def get_distributed_model(model_path): "initializers": { "default": { "config": { - "backend_priority": {[{"type": TGISBackend.backend_type}]} + "backend_priority": [{"type": TGISBackend.backend_type}] } } } @@ -99,13 +101,17 @@ def get_distributed_model(model_path): # NOTE: bloom-560m is the default here because that's the default model used in our # text generation server hack script. model_name_override = os.getenv("MODEL_NAME", "bloom-560m") - loaded_base_model = dist_model.base_model_name + if hasattr(dist_model, "base_model_name"): + loaded_base_model = dist_model.base_model_name + else: + loaded_base_model = dist_model.model_name if not model_name_override.endswith(loaded_base_model): - raise ValueError( + log.error( "TGIS using model name: {} conflicts with base model name: {}; set env var MODEL_NAME to the correct base model!".format( model_name_override, loaded_base_model ) ) + return dist_model From eb5c7fdd2b29dfc1b63e06c0dbf51064face9840 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Fri, 4 Aug 2023 15:57:22 -0500 Subject: [PATCH 08/12] :bug::sparkles: Fix dtype issue and add truncate input tokens param Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../text_generation/peft_tgis_remote.py | 38 ++++++++++++--- .../text_generation/text_generation_local.py | 12 +++-- .../text_generation/text_generation_tgis.py | 37 +++++++++++++-- caikit_nlp/toolkit/tgis_utils.py | 46 +++++++++++++++---- 4 files changed, 111 insertions(+), 22 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index af9626e8..65112bea 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -160,7 +160,12 @@ def save(self, model_path: str): @TextGenerationTask.taskmethod() def run( - self, text, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text, + preserve_input_text=False, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens=0, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. Currently we leverage greedy decoding and apply the same verbalizer used for training the local model prior to sending the @@ -178,7 +183,11 @@ def run( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. Returns: GeneratedTextResult Generated text result produced by TGIS. 
@@ -190,12 +199,21 @@ def run( ) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.unary_generate( - verbalized_text, preserve_input_text, max_new_tokens, min_new_tokens + verbalized_text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) @TextGenerationTask.taskmethod(output_streaming=True) def run_stream_out( - self, text: str, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text: str, + preserve_input_text=False, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens=0, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing against the model running in TGIS @@ -211,7 +229,11 @@ def run_stream_out( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. Returns: Iterable[GeneratedTextStreamResult] """ @@ -223,5 +245,9 @@ def run_stream_out( ) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.stream_generate( - verbalized_text, preserve_input_text, max_new_tokens, min_new_tokens + verbalized_text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index c27fb2c8..7ab0d89f 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -88,7 +88,7 @@ def __del__(self): pass @classmethod - def bootstrap(cls, base_model_path: str): + def bootstrap(cls, base_model_path: str, torch_dtype: str = "float32"): """Function to bootstrap a pre-trained transformers model and get a caikit text-generation 'model'. @@ -96,6 +96,9 @@ def bootstrap(cls, base_model_path: str): base_model_path: str Path to transformers model NOTE: Model path needs to contain tokenizer as well + torch_dtype: str + Torch data type to be used when loading the model. 
+ Default: float32 Returns: caikit_nlp.blocks.text_generation.TextGeneration Object of TextGeneration class (model) @@ -118,7 +121,9 @@ def bootstrap(cls, base_model_path: str): ) log.debug("Bootstrapping base resource [%s]", base_model_path) base_model = resource_type.bootstrap( - base_model_path, tokenizer_name=base_model_path + base_model_path, + tokenizer_name=base_model_path, + torch_dtype=torch_dtype, ) eos_token = base_model._tokenizer.eos_token or None return cls( @@ -334,7 +339,7 @@ def load( if torch_dtype is not None: torch_dtype = str_to_torch_dtype(torch_dtype) - else: + elif config.trained_torch_dtype: torch_dtype = str_to_torch_dtype(config.trained_torch_dtype) base_model_path = config.get("artifact_path") @@ -361,6 +366,7 @@ def save(self, model_path): { "artifact_path": artifacts_dir, "eos_token": self._eos_token, + "torch_dtype": str(self.model._torch_dtype), } ) if self.model: diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 7cd87d26..fcc06aa9 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -178,7 +178,12 @@ def save(self, model_path: str): @TextGenerationTask.taskmethod() def run( - self, text, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text, + preserve_input_text=False, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens=0, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. @@ -194,18 +199,32 @@ def run( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. Returns: GeneratedTextResult Generated text result produced by TGIS. """ if self._model_loaded: return self.tgis_generation_client.unary_generate( - text, preserve_input_text, max_new_tokens, min_new_tokens + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) @TextGenerationTask.taskmethod(output_streaming=True) def run_stream_out( - self, text: str, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text: str, + preserve_input_text=False, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens=0, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing for text generation module. @@ -219,11 +238,19 @@ def run_stream_out( Maximum tokens for the model to generate min_new_tokens: int Minimum tokens for the model to generate - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. 
Returns: Iterable[GeneratedTextStreamResult] """ if self._model_loaded: return self.tgis_generation_client.stream_generate( - text, preserve_input_text, max_new_tokens, min_new_tokens + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py index be32657d..a1d44d4c 100644 --- a/caikit_nlp/toolkit/tgis_utils.py +++ b/caikit_nlp/toolkit/tgis_utils.py @@ -31,19 +31,27 @@ error = error_handler.get(log) -def get_params(preserve_input_text, eos_token, max_new_tokens, min_new_tokens): +def get_params( + preserve_input_text, + eos_token, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, +): """Get generation parameters Args: - preserve_input_text: str + preserve_input_text: str Whether or not the source string should be contained in the generated output, e.g., as a prefix. - eos_token: str + eos_token: str A special token representing the end of a sentence. - max_new_tokens: int + max_new_tokens: int The maximum numbers of tokens to generate. - min_new_tokens: int + min_new_tokens: int The minimum numbers of tokens to generate. + truncate_input_tokens: int + Truncate inputs to provided number of tokens. """ res_options = generation_pb2.ResponseOptions( input_text=preserve_input_text, @@ -60,6 +68,7 @@ def get_params(preserve_input_text, eos_token, max_new_tokens, min_new_tokens): params = generation_pb2.Parameters( response=res_options, stopping=stopping, + truncate_input_tokens=truncate_input_tokens, ) return params @@ -77,7 +86,12 @@ def __init__( self.prefix_id = prefix_id def unary_generate( - self, text, preserve_input_text, max_new_tokens, min_new_tokens + self, + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) -> GeneratedTextResult: """Generate unary output from model in TGIS @@ -93,7 +107,11 @@ def unary_generate( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + 0 - means don't truncate, thus throw error. Returns: GeneratedTextResult Generated text result produced by TGIS. @@ -114,6 +132,7 @@ def unary_generate( eos_token=self.eos_token, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, ) gen_reqs = [generation_pb2.GenerationRequest(text=text)] @@ -150,7 +169,12 @@ def unary_generate( ) def stream_generate( - self, text, preserve_input_text, max_new_tokens, min_new_tokens + self, + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) -> Iterable[GeneratedTextStreamResult]: """Generate stream output from model in TGIS @@ -164,6 +188,11 @@ def stream_generate( Maximum tokens for the model to generate min_new_tokens: int Minimum tokens for the model to generate + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + 0 - means don't truncate, thus throw error. 
Returns: Iterable[GeneratedTextStreamResult] @@ -183,6 +212,7 @@ def stream_generate( eos_token=self.eos_token, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, ) gen_req = generation_pb2.GenerationRequest(text=text) From a747ce7023e2381e68926b263d44f8c4e1f443de Mon Sep 17 00:00:00 2001 From: Gaurav Kumbhat Date: Mon, 7 Aug 2023 16:49:12 -0500 Subject: [PATCH 09/12] Apply suggestions from code review Co-authored-by: Alex Brooks Signed-off-by: Gaurav Kumbhat Signed-off-by: gkumbhat --- examples/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/utils.py b/examples/utils.py index f3c610be..cfb944fd 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -100,7 +100,7 @@ def get_distributed_model(model_path): # make sure that its suffix (base model name) aligns with what we have in our config. # NOTE: bloom-560m is the default here because that's the default model used in our # text generation server hack script. - model_name_override = os.getenv("MODEL_NAME", "bloom-560m") + model_name_override = os.getenv("MODEL_NAME", "bigscience/bloom-560m") if hasattr(dist_model, "base_model_name"): loaded_base_model = dist_model.base_model_name else: From 9a430d3ee8b101d7865b515ff4e37f8a288a6b3c Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Mon, 7 Aug 2023 17:33:24 -0500 Subject: [PATCH 10/12] :bug: Remove temp hack for handling truncation in TGIS long Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- .../text_generation/text_generation_local.py | 36 ++++++++++++++----- .../text_generation/text_generation_tgis.py | 25 ++++++++++--- examples/run_fine_tuning.py | 18 +++++----- 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 7ab0d89f..21661dbd 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -379,13 +379,14 @@ def save(self, model_path): def run( self, - text, - repetition_penalty=2.5, - length_penalty=1.0, - early_stopping=True, - num_beams=1, - max_new_tokens=20, - min_new_tokens=0, + text: str, + repetition_penalty: float = 2.5, + length_penalty: float = 1.0, + early_stopping: bool = True, + num_beams: int = 1, + max_new_tokens: int = 20, + min_new_tokens: int = 0, + truncate_input_tokens: int = 0, **kwargs, ) -> "GeneratedTextResult": """Run inference against the model running in TGIS. @@ -421,6 +422,11 @@ def run( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. kwargs: Any other parameters to pass to generate as specified in GenerationConfig. https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/text_generation#transformers.GenerationConfig @@ -429,7 +435,21 @@ def run( Generated text result produced by the model. 
""" - inputs = self.model.tokenizer(text, return_tensors="pt") + # NOTE: below is to match TGIS API, where 0 identifies as no truncation + if truncate_input_tokens: + # NOTE: below will make model throw error in case inputs are longer + # than allowed length + truncation = False + + else: + truncation = True + + inputs = self.model.tokenizer( + text, + truncation=truncation, + max_length=truncate_input_tokens, + return_tensors="pt", + ) generate_ids = self.model.model.generate( input_ids=inputs["input_ids"], num_beams=num_beams, diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index fcc06aa9..96a3d4b0 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -97,6 +97,21 @@ def __del__(self): @classmethod def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None): + """Function to bootstrap a pre-trained transformers model and + get a caikit text-generation 'model'. + + Args: + base_model_path: str + Path to transformers model + NOTE: Model path needs to contain tokenizer as well + load_backend: BackendBase + Backend object to be used to run inference with. + NOTE: this is required for inferencing. It is + made optional to support the model conversion use-case + Returns: + caikit_nlp.blocks.text_generation.TextGeneration + Object of TextGeneration class (model) + """ text_generation_inst = TextGeneration.bootstrap(model_path) bos_token = text_generation_inst.model._tokenizer.bos_token @@ -179,11 +194,11 @@ def save(self, model_path: str): @TextGenerationTask.taskmethod() def run( self, - text, - preserve_input_text=False, - max_new_tokens=20, - min_new_tokens=0, - truncate_input_tokens=0, + text: str, + preserve_input_text: bool = False, + max_new_tokens: int = 20, + min_new_tokens: int = 0, + truncate_input_tokens: int = 0, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. diff --git a/examples/run_fine_tuning.py b/examples/run_fine_tuning.py index 65733497..71e7b2d5 100644 --- a/examples/run_fine_tuning.py +++ b/examples/run_fine_tuning.py @@ -232,7 +232,7 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None: print_colored("\n".join([print_str for print_str in print_strs if print_str])) -def get_model_preds_and_references(model, validation_stream): +def get_model_preds_and_references(model, validation_stream, truncate_input_tokens): """Given a model & a validation stream, run the model against every example in the validation stream and compare the outputs to the target/output sequence. @@ -242,7 +242,9 @@ def get_model_preds_and_references(model, validation_stream): validation_stream: DataStream[GenerationTrainRecord] Validation stream with labeled targets that we want to compare to our model's predictions. - + truncate_input_tokens: int + maximum number of tokens to be accepted by the model and rest will be + truncated. Returns: Tuple(List) Tuple of 2 lists; the model predictions and the expected output sequences. @@ -253,12 +255,9 @@ def get_model_preds_and_references(model, validation_stream): for datum in tqdm(validation_stream): # Local .run() currently prepends the input text to the generated string; # Ensure that we're just splitting the first predicted token & beyond. 
- # TEMP HACK to avoid too big input to TGIS problem: - try: - raw_model_text = model.run(datum.input).generated_text - except Exception as ex: - print(ex) - continue + raw_model_text = model.run( + datum.input, truncate_input_tokens=truncate_input_tokens + ).generated_text parse_pred_text = raw_model_text.split(datum.input)[-1].strip() model_preds.append(parse_pred_text) targets.append(datum.output) @@ -359,8 +358,9 @@ def export_model_preds(preds_file, predictions, validation_stream): validation_stream = dataset_info.dataset_loader()[1] print_colored("Getting model predictions...") + truncate_input_tokens = args.args.max_source_length + args.max_target_length predictions, references = get_model_preds_and_references( - loaded_model, validation_stream + loaded_model, validation_stream, truncate_input_tokens ) export_model_preds(args.preds_file, predictions, validation_stream) From 4fed554722f131524f0660e30dd41ebadf01fbee Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Mon, 7 Aug 2023 17:48:45 -0500 Subject: [PATCH 11/12] :bug: Fix token truncation input condition Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- caikit_nlp/modules/text_generation/text_generation_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 21661dbd..0d2223b9 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -436,7 +436,7 @@ def run( """ # NOTE: below is to match TGIS API, where 0 identifies as no truncation - if truncate_input_tokens: + if truncate_input_tokens == 0: # NOTE: below will make model throw error in case inputs are longer # than allowed length truncation = False From 6907c9d8f71a831c539d5409e0615db4a9523b32 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Mon, 7 Aug 2023 18:20:02 -0500 Subject: [PATCH 12/12] :art: Move artifact path to use os path join Signed-off-by: gkumbhat Signed-off-by: gkumbhat --- examples/run_fine_tuning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/run_fine_tuning.py b/examples/run_fine_tuning.py index 71e7b2d5..88146114 100644 --- a/examples/run_fine_tuning.py +++ b/examples/run_fine_tuning.py @@ -342,9 +342,9 @@ def export_model_preds(preds_file, predictions, validation_stream): # HACK: export args.output_dir as MODEL_NAME for TGIS # container to pick up automatically os.environ["MODEL_DIR"] = os.path.dirname(args.output_dir) - os.environ[ - "MODEL_NAME" - ] = f"/models/{os.path.basename(args.output_dir)}/artifacts" + os.environ["MODEL_NAME"] = os.path.join( + "models", os.path.basename(args.output_dir), "artifacts" + ) loaded_model = load_model(is_distributed=True, model_path=args.output_dir)
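For readers tracing the truncation fix in PATCH 11, the mapping from the TGIS-style truncate_input_tokens value to Hugging Face tokenizer arguments can be sketched in isolation; the tokenizer checkpoint below is an assumption used only for illustration and is not part of the patch:

    from transformers import AutoTokenizer

    # Illustrative checkpoint only; any HF tokenizer behaves the same way here.
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

    def tokenize_with_optional_truncation(text: str, truncate_input_tokens: int = 0):
        """Mirror the semantics above: 0 means 'do not truncate' (the model is left
        to raise on over-long inputs); any positive value clips the input to that
        many tokens, matching the TGIS API convention."""
        if truncate_input_tokens == 0:
            return tokenizer(text, truncation=False, return_tensors="pt")
        return tokenizer(
            text,
            truncation=True,
            max_length=truncate_input_tokens,
            return_tensors="pt",
        )

The fine-tuning example sizes this budget as the sum of max_source_length and max_target_length, so that evaluation-time inputs stay within the sequence-length budget used for training.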