diff --git a/CHANGELOG.md b/CHANGELOG.md index f360bd9b..5b7867f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.65.0] - 2025-01-31 +## [0.70.0] - 2024-02-15 + +### Added + +- Added support for using LLMs with `LanguageModelingModel` (with or without quantization) +- Added support for adapter tuning LLMs + +## [0.65.0] - 2024-01-31 - Lots of QOL improvements - Added support for evaluating retrieval models with `pytrec_eval` diff --git a/README.md b/README.md index 867f8fb5..90184be9 100755 --- a/README.md +++ b/README.md @@ -9,17 +9,17 @@ This library is based on the [Transformers](https://github.com/huggingface/trans **Supported Tasks:** +- Information Retrieval (Dense Retrieval) +- (Large) Language Models (Training, Fine-tuning, and Generation) +- Encoder Model Training and Fine-tuning - Sequence Classification - Token Classification (NER) - Question Answering -- Language Model Fine-Tuning -- Language Model Training - Language Generation - T5 Model - Seq2Seq Tasks - Multi-Modal Classification - Conversational AI. -- Text Representation Generation. # Table of contents diff --git a/examples/llms/download_squad.ipynb b/examples/llms/download_squad.ipynb new file mode 100644 index 00000000..ed9d3077 --- /dev/null +++ b/examples/llms/download_squad.ipynb @@ -0,0 +1,375 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bfc6efd35877482fab97a58798a178be", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading readme: 0%| | 0.00/7.83k [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "" + ], + "text/plain": [ + " text\n", + "0 Question: To whom did the Virgin Mary allegedl...\n", + "1 Question: What is in front of the Notre Dame M...\n", + "2 Question: The Basilica of the Sacred heart at ...\n", + "3 Question: What is the Grotto at Notre Dame? An...\n", + "4 Question: What sits on top of the Main Buildin...\n", + "... ...\n", + "87594 Question: In what US state did Kathmandu first...\n", + "87595 Question: What was Yangon previously known as?...\n", + "87596 Question: With what Belorussian city does Kath...\n", + "87597 Question: In what year did Kathmandu create it...\n", + "87598 Question: What is KMC an initialism of? Answer...\n", + "\n", + "[87599 rows x 1 columns]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_df = pd.DataFrame(train_texts, columns=[\"text\"])\n", + "eval_df = pd.DataFrame({\"text\": eval_texts, \"answer\": eval_answers})\n", + "\n", + "train_df" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "train_df.to_json(\"../data/squad-train.jsonl\", orient=\"records\", lines=True)\n", + "eval_df.to_json(\"../data/squad-eval.jsonl\", orient=\"records\", lines=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "train_df[\"rag_query\"] = train_data[\"question\"]\n", + "eval_df[\"rag_query\"] = eval_data[\"question\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "train_df.to_json(\"../data/squad-train.jsonl\", orient=\"records\", lines=True)\n", + "eval_df.to_json(\"../data/squad-eval.jsonl\", orient=\"records\", lines=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "st", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/llms/train.py b/examples/llms/train.py new file mode 100644 index 00000000..c8eb2bf2 --- /dev/null +++ b/examples/llms/train.py @@ -0,0 +1,98 @@ +import logging +import pandas as pd + +from simpletransformers.language_modeling import LanguageModelingModel, LanguageModelingArgs, GenerationArgs +from simpletransformers.retrieval import ( + RetrievalModel, + RetrievalArgs, +) + + +logging.basicConfig(level=logging.INFO) +transformers_logger = logging.getLogger("transformers") +transformers_logger.setLevel(logging.WARNING) +trainsformers_modules_logger = logging.getLogger("transformers_modules") +trainsformers_modules_logger.setLevel(logging.ERROR) + + +rag = True + +if rag: + from rag_setup import model as rag_model + + +train_file = "../data/squad-train.jsonl" + + +model_args = LanguageModelingArgs() +model_args.reprocess_input_data = True +model_args.overwrite_output_dir = True +model_args.num_train_epochs = 1 +model_args.save_eval_checkpoints = False +model_args.save_model_every_epoch = False +model_args.train_batch_size = 2 +model_args.eval_batch_size = 4 +model_args.gradient_accumulation_steps = 1 +model_args.manual_seed = 4 +model_args.fp16 = True +model_args.dataset_type = "simple" +model_args.logging_steps = 100 
+model_args.evaluate_during_training = False +model_args.mlm = False +model_args.use_multiprocessing = False +model_args.use_hf_datasets = True +model_args.peft = True +model_args.qlora = False +model_args.nf4 = True +model_args.loftq_bits = 4 +model_args.lora_config = {"r": 8} +model_args.data_format = "jsonl" +model_args.trust_remote_code = True +model_args.save_steps = 1000 +model_args.optimizer = "Adam8bit" +model_args.chunk_text = False +model_args.max_seq_length = 500 + +if not rag: + model_args.wandb_project = "llama-adapter-tuning-squad" + model_args.wandb_kwargs = {"name": "squad-llama-2-7b-vanilla"} + model_args.output_dir = "squad-llama-2-7b" + +if rag: + model_args.rag = rag + model_args.wandb_project = "llama-adapter-tuning-squad" + model_args.wandb_kwargs = {"name": "squad-llama-2-7b-rag"} + model_args.output_dir = "squad-llama-2-7b-rag" + + +model = LanguageModelingModel( + "causal", + # "outputs", + # "stabilityai/stablelm-zephyr-3b", + "meta-llama/Llama-2-7b-hf", + args=model_args, + retrieval_model=rag_model if rag else None, +) + +model.train_model( + train_file, +) + +# generation_args = GenerationArgs() +# generation_args.max_length = None +# generation_args.max_new_tokens = 100 + +# test_df = pd.read_json("data/test.jsonl", lines=True) + +# to_predict = test_df["input_text"].tolist() + +# responses, _ = model.predict( +# to_predict, +# generation_args=generation_args, +# ) + +# print(responses[:5]) + +# test_df["generated_text"] = responses + +# test_df.to_json("data/test_output-finetuned.jsonl", orient="records", lines=True) \ No newline at end of file diff --git a/examples/retrieval/download_msmarco.py b/examples/retrieval/download_msmarco.py new file mode 100644 index 00000000..7976fcda --- /dev/null +++ b/examples/retrieval/download_msmarco.py @@ -0,0 +1,22 @@ +import os +from datasets import load_dataset + + +os.makedirs("data/msmarco", exist_ok=True) + +print("=== Downloading MSMARCO ===") +print("Downloading MSMARCO training triples...") +dataset = load_dataset("thilina/negative-sampling")["train"] + +print("Dataset loaded. 
Sample:") +print(dataset[0]) + +qrels = load_dataset("BeIR/msmarco-qrels")["validation"] + +print("Saving dataset to disk...") +# Save the dataset to disk +dataset.to_csv("data/msmarco/msmarco-train.tsv", sep="\t", index=False) +qrels.to_csv("data/msmarco/devs.tsv", sep="\t", index=False) + +print("Done.") +print("=== MSMARCO download complete ===") diff --git a/examples/retrieval/train_dpr_base.py b/examples/retrieval/train_dpr_base.py new file mode 100644 index 00000000..c6562704 --- /dev/null +++ b/examples/retrieval/train_dpr_base.py @@ -0,0 +1,64 @@ +import logging + +import pandas as pd +from simpletransformers.retrieval import RetrievalModel, RetrievalArgs + +# Configuring logging +logging.basicConfig(level=logging.INFO) +transformers_logger = logging.getLogger("transformers") +transformers_logger.setLevel(logging.WARNING) + +# Specifying the path to the training data +train_data_path = "data/msmarco/msmarco-train.tsv" + +# Loading the training data +if train_data_path.endswith(".tsv"): + train_data = pd.read_csv(train_data_path, sep="\t") +else: + train_data = train_data_path + +# Configuring the model arguments +model_args = RetrievalArgs() +model_args.reprocess_input_data = True +model_args.overwrite_output_dir = True +model_args.use_cached_eval_features = False +model_args.include_title = False if "msmarco" in train_data_path else True +model_args.max_seq_length = 256 +model_args.num_train_epochs = 40 +model_args.train_batch_size = 16 +model_args.use_hf_datasets = True +model_args.learning_rate = 1e-6 +model_args.warmup_steps = 5000 +model_args.save_steps = 300000 +model_args.evaluate_during_training = True +model_args.evaluate_during_training_steps = False +model_args.save_model_every_epoch = False +model_args.wandb_project = "Retrieval training example" +model_args.hard_negatives = False +model_args.n_gpu = 1 +model_args.data_format = "beir" +model_args.output_dir = f"trained_models/pretrained/DPR-base-msmarco" +model_args.wandb_kwargs = {"name": f"DPR-base-msmarco"} + +# Defining the model type and names +model_type = "custom" +model_name = None +context_name = "bert-base-multilingual-cased" +question_name = "bert-base-multilingual-cased" + +# Main execution +if __name__ == "__main__": + # Creating the model + model = RetrievalModel( + model_type, + model_name, + context_name, + question_name, + args=model_args, + ) + + # Training the model + model.train_model( + train_data, + eval_set="dev", + ) diff --git a/setup.py b/setup.py index 7d924cc3..657771f1 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="simpletransformers", - version="0.65.1", + version="0.70.0", author="Thilina Rajapakse", author_email="chaturangarajapakshe@gmail.com", description="An easy-to-use wrapper library for the Transformers library.", diff --git a/simpletransformers/classification/classification_model.py b/simpletransformers/classification/classification_model.py index 66b01238..6ea168b4 100755 --- a/simpletransformers/classification/classification_model.py +++ b/simpletransformers/classification/classification_model.py @@ -22,12 +22,14 @@ from scipy.special import softmax from sklearn.metrics import ( confusion_matrix, + f1_score, label_ranking_average_precision_score, matthews_corrcoef, mean_squared_error, roc_curve, auc, average_precision_score, + accuracy_score, ) from torch.utils.tensorboard import SummaryWriter from torch.nn import CrossEntropyLoss @@ -1155,8 +1157,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - 
"checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: @@ -1924,6 +1925,8 @@ def compute_metrics( return {**extra_metrics}, wrong mcc = matthews_corrcoef(labels, preds) + accuracy = accuracy_score(labels, preds) + f1 = f1_score(labels, preds) if self.model.num_labels == 2: tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel() if self.args.sliding_window: @@ -1943,6 +1946,8 @@ def compute_metrics( { **{ "mcc": mcc, + "accuracy": accuracy, + "f1_score": f1, "tp": tp, "tn": tn, "fp": fp, diff --git a/simpletransformers/classification/multi_modal_classification_model.py b/simpletransformers/classification/multi_modal_classification_model.py index b59acca7..9ec83919 100644 --- a/simpletransformers/classification/multi_modal_classification_model.py +++ b/simpletransformers/classification/multi_modal_classification_model.py @@ -814,8 +814,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: diff --git a/simpletransformers/config/model_args.py b/simpletransformers/config/model_args.py index 2ab9da85..7e55cc57 100644 --- a/simpletransformers/config/model_args.py +++ b/simpletransformers/config/model_args.py @@ -93,6 +93,7 @@ class ModelArgs: tokenizer_type: str = None train_batch_size: int = 8 train_custom_parameters_only: bool = False + trust_remote_code: bool = False use_cached_eval_features: bool = False use_early_stopping: bool = False use_hf_datasets: bool = False @@ -246,6 +247,36 @@ class T5Args(ModelArgs): use_multiprocessed_decoding: bool = True +@dataclass +class GenerationArgs: + """ + Args for language generation. 
+ """ + + max_length: int = 20 + max_new_tokens: int = None + min_length: int = 0 + min_new_tokens: int = None + early_stopping: bool = False + max_time: float = None + + do_sample: bool = False + num_beams: int = 1 + num_beam_groups: int = 1 + penalty_alpha: float = None + use_cache: bool = True + + temperature: float = 1.0 + top_k: int = 50 + top_p: float = 1.0 + repetition_penalty: float = 1.0 + diversity_penalty: float = 0.0 + + def get_dict(self): + d = asdict(self) + return {k: v for k, v in d.items() if v is not None} + + @dataclass class LanguageModelingArgs(ModelArgs): """ @@ -254,6 +285,7 @@ class LanguageModelingArgs(ModelArgs): model_class: str = "LanguageModelingModel" block_size: int = -1 + chunk_text: bool = True config_name: str = None dataset_class: Dataset = None dataset_type: str = "None" @@ -276,6 +308,14 @@ class LanguageModelingArgs(ModelArgs): special_tokens_list: list = field(default_factory=list) strip_accents: bool = True local_rank: int = -1 + loftq_bits: int = 4 + loftq_config: dict = field(default_factory=dict) + lora_config: dict = field(default_factory=dict) + peft: bool = False + qlora: bool = False + rag: bool = False + rag_replace_method: str = "prepend" + nf4: bool = False use_autoencoder: bool = False stream_hf_datasets: bool = False diff --git a/simpletransformers/conv_ai/conv_ai_model.py b/simpletransformers/conv_ai/conv_ai_model.py index b8162f9f..f7a33da7 100644 --- a/simpletransformers/conv_ai/conv_ai_model.py +++ b/simpletransformers/conv_ai/conv_ai_model.py @@ -723,8 +723,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: diff --git a/simpletransformers/language_modeling/__init__.py b/simpletransformers/language_modeling/__init__.py index bde1ccf1..264f6e5e 100755 --- a/simpletransformers/language_modeling/__init__.py +++ b/simpletransformers/language_modeling/__init__.py @@ -1,4 +1,4 @@ -from simpletransformers.config.model_args import LanguageModelingArgs +from simpletransformers.config.model_args import LanguageModelingArgs, GenerationArgs from simpletransformers.language_modeling.language_modeling_model import ( LanguageModelingModel, ) diff --git a/simpletransformers/language_modeling/language_modeling_model.py b/simpletransformers/language_modeling/language_modeling_model.py index c9cdb307..31dea2f4 100755 --- a/simpletransformers/language_modeling/language_modeling_model.py +++ b/simpletransformers/language_modeling/language_modeling_model.py @@ -49,6 +49,7 @@ from simpletransformers.custom_models.models import RobertaWithAutoEncoderForMaskedLM from transformers import DummyObject, requires_backends +from torch.utils.data import DataLoader, TensorDataset class NystromformerTokenizer(metaclass=DummyObject): @@ -63,6 +64,7 @@ def __init__(self, *args, **kwargs): AutoConfig, AutoModelWithLMHead, AutoTokenizer, + AutoModelForCausalLM, BertConfig, BertForMaskedLM, BertTokenizer, @@ -102,6 +104,7 @@ def __init__(self, *args, **kwargs): XLMRobertaConfig, XLMRobertaForMaskedLM, XLMRobertaTokenizer, + GenerationConfig, ) from transformers.data.datasets.language_modeling import ( LineByLineTextDataset, @@ -109,7 +112,7 @@ def __init__(self, *args, **kwargs): ) from simpletransformers.config.global_args import global_args -from simpletransformers.config.model_args import LanguageModelingArgs +from 
simpletransformers.config.model_args import LanguageModelingArgs, GenerationArgs
 from simpletransformers.config.utils import sweep_config_to_sweep_values
 from simpletransformers.custom_models.models import ElectraForLanguageModelingModel
 from simpletransformers.language_modeling.language_modeling_utils import (
@@ -132,6 +135,7 @@ def __init__(self, *args, **kwargs):
     "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
     "bigbird": (BigBirdConfig, BigBirdForMaskedLM, BigBirdTokenizer),
     "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
+    "causal": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
     "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
     "electra": (ElectraConfig, ElectraForLanguageModelingModel, ElectraTokenizer),
     "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
@@ -159,7 +163,9 @@ def __init__(
         train_files=None,
         args=None,
         use_cuda=True,
-        autoencoder_model=None,
+        retrieval_model=None,
+        adapter_name=None,
+        # autoencoder_model=None,
         cuda_device=-1,
         **kwargs,
    ):
@@ -173,6 +179,8 @@ def __init__(
             discriminator_name (optional): A pretrained model name or path to a directory containing an ELECTRA discriminator model.
             args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
             train_files (optional): List of files to be used when training the tokenizer.
+            retrieval_model (optional): A RetrievalModel instance to be used for Retrieval-Augmented Generation (RAG). This should be preloaded with a knowledge index (prediction passages).
+            adapter_name (optional): Path to (or name of) a trained PEFT adapter to load into the model instead of initializing a new one.
             use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
             cuda_device (optional): Specific GPU that should be used. Will use the first available GPU by default.
             **kwargs (optional): For providing proxies, force_download, resume_download, cache_dir and other options specific to the 'from_pretrained' implementation where this will be supplied.
@@ -258,11 +266,17 @@
 
         if self.args.config_name:
             self.config = config_class.from_pretrained(
-                self.args.config_name, cache_dir=self.args.cache_dir
+                self.args.config_name,
+                cache_dir=self.args.cache_dir,
+                trust_remote_code=self.args.trust_remote_code,
+                **kwargs,
             )
         elif self.args.model_name and self.args.model_name != "electra":
             self.config = config_class.from_pretrained(
-                model_name, cache_dir=self.args.cache_dir, **kwargs
+                model_name,
+                cache_dir=self.args.cache_dir,
+                trust_remote_code=self.args.trust_remote_code,
+                **kwargs,
             )
         else:
             self.config = config_class(**self.args.config, **kwargs)
@@ -359,12 +373,65 @@
                     )
                 )
             else:
-                self.model = model_class.from_pretrained(
-                    model_name,
-                    config=self.config,
-                    cache_dir=self.args.cache_dir,
-                    **kwargs,
-                )
+                if self.args.nf4:
+                    from transformers import BitsAndBytesConfig
+
+                    nf4_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_use_double_quant=True,
+                        bnb_4bit_compute_dtype=torch.bfloat16,
+                    )
+                    self.model = model_class.from_pretrained(
+                        model_name,
+                        quantization_config=nf4_config,
+                        trust_remote_code=self.args.trust_remote_code,
+                    )
+                else:
+                    self.model = model_class.from_pretrained(
+                        model_name,
+                        config=self.config,
+                        cache_dir=self.args.cache_dir,
+                        **kwargs,
+                    )
+
+                if self.args.peft:
+                    from peft import LoraConfig, get_peft_model, LoftQConfig
+                    from peft.peft_model import PeftModel
+
+                    if self.args.qlora:
+                        if self.args.nf4:
+                            raise ValueError(
+                                "QLoRA (LoftQ) and NF4 quantization cannot be used"
+                                " together. Set either args.qlora or args.nf4, not both."
+                            )
+                        loftq_config = LoftQConfig(
+                            loftq_bits=self.args.loftq_bits, **self.args.loftq_config
+                        )
+                        self.lora_config = LoraConfig(
+                            init_lora_weights="loftq",
+                            target_modules="all-linear",
+                            loftq_config=loftq_config,
+                            **self.args.lora_config,
+                        )
+                        self.args.fp16 = False
+                    else:
+                        self.lora_config = LoraConfig(
+                            use_rslora=True,
+                            target_modules="all-linear",
+                            **self.args.lora_config,
+                        )
+                    self.model.gradient_checkpointing_enable()
+                    self.model.enable_input_require_grads()
+                    self.adapter_name = adapter_name
+                    if adapter_name is not None:
+                        self.model = PeftModel.from_pretrained(
+                            self.model,
+                            model_id=adapter_name,
+                            adapter_name=adapter_name,
+                        )
+                    else:
+                        self.model = get_peft_model(self.model, self.lora_config)
+                        self.model.print_trainable_parameters()
+
         else:
             logger.info(" Training language model from scratch")
@@ -432,6 +499,15 @@ def __init__(
             )
             self.args.wandb_project = None
 
+        # Always set retrieval_model (it may be None) so that downstream code,
+        # e.g. load_and_cache_examples, can reference it even when RAG is off.
+        self.retrieval_model = retrieval_model
+        if self.args.rag and retrieval_model is None:
+            raise ValueError(
+                "RAG is enabled but no retrieval model is specified."
+                " Pass a retrieval model when instantiating the LanguageModelingModel to use RAG."
+ ) + def train_model( self, train_file, @@ -608,6 +684,15 @@ def collate(examples: List[torch.Tensor]): relative_step=args.adafactor_relative_step, warmup_init=args.adafactor_warmup_init, ) + elif args.optimizer == "Adam8bit": + from bitsandbytes.optim import Adam8bit + + optimizer = Adam8bit( + optimizer_grouped_parameters, + lr=args.learning_rate, + eps=args.adam_epsilon, + betas=args.adam_betas, + ) else: raise ValueError( @@ -769,7 +854,6 @@ def collate(examples: List[torch.Tensor]): # TODO: Move this to _get_inputs_dict and keep the attention masks if self.args.use_hf_datasets: if self.args.stream_hf_datasets: - # BUG: TOKENIZATION IS BUGGED FOR HF DATASETS batch["input_ids"] = torch.stack(batch["input_ids"]) batch["attention_mask"] = torch.stack(batch["attention_mask"]) if self.args.model_type in [ @@ -806,16 +890,22 @@ def collate(examples: List[torch.Tensor]): ) token_type_ids = ( batch["token_type_ids"].to(self.device) - if self.args.use_hf_datasets + if self.args.use_hf_datasets and "token_type_ids" in batch else None ) labels = labels.to(self.device) - inputs_dict = { - "input_ids": inputs, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } + if token_type_ids is None: + inputs_dict = { + "input_ids": inputs, + "attention_mask": attention_mask, + } + else: + inputs_dict = { + "input_ids": inputs, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } if args.fp16: with amp.autocast(): @@ -1100,8 +1190,7 @@ def collate(examples: List[torch.Tensor]): epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: @@ -1249,7 +1338,13 @@ def collate(examples: List[torch.Tensor]): ) def eval_model( - self, eval_file, output_dir=None, verbose=True, silent=False, **kwargs + self, + eval_file, + output_dir=None, + evaluate_generated_text=False, + verbose=True, + silent=False, + **kwargs, ): """ Evaluates the model on eval_df. Saves results to args.output_dir @@ -1261,26 +1356,30 @@ def eval_model( self._move_model_to_device() - eval_dataset = self.load_and_cache_examples( - eval_file, evaluate=True, verbose=verbose, silent=silent - ) - os.makedirs(output_dir, exist_ok=True) + if evaluate_generated_text: + raise NotImplementedError( + "evaluate_generated_text is not yet implemented for this model type." 
+            )
+        else:
+            eval_dataset = self.load_and_cache_examples(
+                eval_file, evaluate=True, verbose=verbose, silent=silent
+            )
+            os.makedirs(output_dir, exist_ok=True)
 
-        result = self.evaluate(
-            eval_dataset, output_dir, verbose=verbose, silent=silent, **kwargs
-        )
-        self.results.update(result)
+            result = self.evaluate(
+                eval_dataset, output_dir, verbose=verbose, silent=silent, **kwargs
+            )
+            self.results.update(result)
 
-        if verbose:
-            logger.info(self.results)
+            if verbose:
+                logger.info(self.results)
 
-        return result
+            return result
 
     def evaluate(
         self,
         eval_dataset,
         output_dir,
-        multi_label=False,
         prefix="",
         verbose=True,
         silent=False,
@@ -1331,18 +1430,39 @@ def collate(examples: List[torch.Tensor]):
             eval_dataloader, disable=args.silent or silent, desc="Running Evaluation"
         ):
             if self.args.use_hf_datasets:
-                batch = batch["input_ids"]
+                input_ids = batch["input_ids"]
+            else:
+                input_ids = batch
             inputs, labels = (
-                mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
+                mask_tokens(input_ids, tokenizer, args)
+                if args.mlm
+                else (input_ids, input_ids)
             )
             inputs = inputs.to(self.device)
             labels = labels.to(self.device)
+
+            if self.args.use_hf_datasets and "token_type_ids" in batch:
+                inputs_dict = {
+                    "input_ids": inputs,
+                    "attention_mask": batch["attention_mask"].to(self.device),
+                    "token_type_ids": batch["token_type_ids"].to(self.device),
+                }
+            elif self.args.use_hf_datasets:
+                inputs_dict = {
+                    "input_ids": inputs,
+                    "attention_mask": batch["attention_mask"].to(self.device),
+                }
+            else:
+                inputs_dict = {"input_ids": inputs}
+
             with torch.no_grad():
                 outputs = (
-                    model(inputs, labels=labels)
+                    model(**inputs_dict, labels=labels)
                     if args.mlm
-                    else model(inputs, labels=labels)
+                    else model(**inputs_dict, labels=labels)
                 )
                 if args.model_type == "electra":
                     g_loss = outputs[0]
@@ -1368,6 +1488,127 @@ def collate(examples: List[torch.Tensor]):
 
         return results
 
+    def predict(
+        self,
+        to_predict,
+        generation_args=None,
+        rag_queries=None,
+        knowledge_dataset=None,
+        **kwargs,
+    ):
+        """
+        Performs text completions on a list of text. To be used with language models.
+
+        Args:
+
+        to_predict: A list of text to make predictions on.
+        generation_args: An instance of the `GenerationArgs` class containing the generation arguments for the model.
+        rag_queries (optional): A list of text to be used as queries for the RAG model. Only applicable if rag is enabled.
+        knowledge_dataset (optional): A list of text to be used as knowledge for the RAG model. Only applicable if the model is a RAG model.
+        **kwargs: Additional arguments to be passed to the model's `generate()` method during inference.
+
+        Returns:
+        responses: A list of the predicted sequences.
+        generation_output: The raw output of the final `generate()` call.
+        """
+        self._move_model_to_device()
+
+        if not generation_args:
+            generation_args = GenerationArgs()
+
+        if self.args.peft and self.adapter_name:
+            logger.info(
+                "Merging adapter with model for faster inference. Continuing training from this point may result in unexpected behavior."
+            )
+            self.model = self.model.merge_and_unload()
+
+        self.tokenizer.padding_side = "left"
+
+        if self.args.rag:
+            if not rag_queries:
+                rag_queries = to_predict
+                warnings.warn(
+                    "No `rag_queries` provided. Using `to_predict` as `rag_queries`."
+ ) + + context_docs = self.retrieval_model.predict( + rag_queries, + passages_only=True, + prediction_passages=knowledge_dataset, + ) + + to_predict = [ + f"Context: {' '.join(context_doc)} {text}" + for context_doc, text in zip(context_docs, to_predict) + ] + # TODO: + # - Simplest option is to just prepend context: context_docs to to_predict + # - Advanced option is to have ... in to_predict and then replace ... with context_docs + + try: + inputs = self.tokenizer( + to_predict, + padding=True, + return_tensors="pt", + ) + except ValueError: + if not self.tokenizer.pad_token: + warnings.warn( + "The tokenizer you are using does not have a pad token set. Setting to `eos_token`." + ) + self.tokenizer.pad_token = self.tokenizer.eos_token + inputs = self.tokenizer( + to_predict, + padding=True, + return_tensors="pt", + ) + + input_ids_tensor = inputs["input_ids"] + attention_mask_tensor = inputs["attention_mask"] + + # Create a TensorDataset + dataset = TensorDataset(input_ids_tensor, attention_mask_tensor) + + # Define batch size + + # Create the dataloader + predict_dataloader = DataLoader( + dataset, batch_size=self.args.eval_batch_size, shuffle=False + ) + + # Put model in evaluation mode + self.model.eval() + + # Predict + responses = [] + outputs = [] + for batch in tqdm( + predict_dataloader, desc="Generating outputs", disable=self.args.silent + ): + batch = tuple(t.to(self.device) for t in batch) + input_ids, attention_mask = batch + + generation_output = self.model.generate( + input_ids, + attention_mask=attention_mask, + return_dict_in_generate=True, + output_scores=True, + **generation_args.get_dict(), + **kwargs, + ) + + # response_tests = self.tokenizer.batch_decode(generation_output.sequences[:, input_ids.shape[1]:], skip_special_tokens=True) + + for i, s in enumerate(generation_output.sequences): + output = self.tokenizer.decode( + s[input_ids[i].shape[0] :], skip_special_tokens=True + ) + responses.append(output) + + # responses.extend(response_tests) + outputs.extend(generation_output) + + return responses, generation_output + def load_and_cache_examples( self, file_path, evaluate=False, no_cache=False, verbose=True, silent=False ): @@ -1389,7 +1630,9 @@ def load_and_cache_examples( mode = "dev" if evaluate else "train" if self.args.use_hf_datasets: - dataset = load_hf_dataset(file_path, tokenizer, self.args) + dataset = load_hf_dataset( + file_path, tokenizer, self.args, retrieval_model=self.retrieval_model + ) return dataset elif args.dataset_class: CustomDataset = args.dataset_class @@ -1629,7 +1872,8 @@ def _threshold(self, x, threshold): return 0 def _move_model_to_device(self): - self.model.to(self.device) + if not self.args.qlora and not self.args.nf4: + self.model.to(self.device) def _create_training_progress_scores(self, **kwargs): extra_metrics = {key: [] for key in kwargs} diff --git a/simpletransformers/language_modeling/language_modeling_utils.py b/simpletransformers/language_modeling/language_modeling_utils.py index 09264119..7ae9fb7b 100644 --- a/simpletransformers/language_modeling/language_modeling_utils.py +++ b/simpletransformers/language_modeling/language_modeling_utils.py @@ -3,6 +3,7 @@ import pickle from multiprocessing import Pool from typing import Tuple +import warnings import torch from torch.utils.data import Dataset @@ -74,13 +75,23 @@ def chunk_sequence(sequence, max_length): def preprocess_and_chunk_batch_for_hf_dataset( - dataset, tokenizer, max_seq_length, max_word_length=100 + dataset, tokenizer, max_seq_length, chunk_text=True ): - 
chunked_texts = [] - for text in dataset["text"]: - chunks = chunk_sequence(text, max_seq_length) - for chunk in chunks: - chunked_texts.append(chunk) + if chunk_text: + chunked_texts = [] + for text in dataset["text"]: + chunks = chunk_sequence(text, max_seq_length) + for chunk in chunks: + chunked_texts.append(chunk) + + logger.info( + "Chunked %d examples into %d chunks with a maximum length of %d.", + len(dataset["text"]), + len(chunked_texts), + max_seq_length, + ) + else: + chunked_texts = dataset["text"] return tokenizer( text=chunked_texts, @@ -99,14 +110,15 @@ def preprocess_batch_for_hf_dataset(dataset, tokenizer, max_seq_length): ) -def load_hf_dataset(data, tokenizer, args): +def load_hf_dataset(data, tokenizer, args, retrieval_model=None): if args.data_format == "text": dataset = load_dataset( + "text", data_files=data, download_mode="force_redownload" if args.reprocess_input_data else "reuse_dataset_if_exists", - streaming=True, + streaming=True if args.stream_hf_datasets else False, ) elif args.data_format == "tsv": dataset = load_dataset( @@ -116,28 +128,94 @@ def load_hf_dataset(data, tokenizer, args): download_mode="force_redownload" if args.reprocess_input_data else "reuse_dataset_if_exists", - streaming=True, + streaming=True if args.stream_hf_datasets else False, + ) + elif args.data_format == "json" or args.data_format == "jsonl": + dataset = load_dataset( + "json", + data_files=data, + download_mode="force_redownload" + if args.reprocess_input_data + else "reuse_dataset_if_exists", + streaming=True if args.stream_hf_datasets else False, ) else: raise ValueError("args.data_format must be either 'text' or 'tsv'") - dataset = dataset.map( - lambda x: preprocess_and_chunk_batch_for_hf_dataset( - x, tokenizer=tokenizer, max_seq_length=args.max_seq_length - ), - batched=True, - remove_columns=["text"], - ) - - # dataset = dataset.with_format( - # type="pt", columns=["input_ids", "token_type_ids", "attention_mask"] - # ) + if retrieval_model: + if retrieval_model.prediction_passages is None: + raise ValueError( + "The RetrievalModel must be initialized with prediction_passages to use it for RAG training." + ) + dataset = dataset["train"] + logger.info("Retrieving context documents for RAG training.") + rag_queries = dataset["rag_query"] + context_docs = retrieval_model.predict(rag_queries, passages_only=True) + retrieval_model.context_encoder.to("cpu") + retrieval_model.query_encoder.to("cpu") + context_docs = [" ".join(docs) for docs in context_docs] + + dataset = dataset.add_column("context", context_docs) + + logger.info("Merging context documents with the original text.") + + def batch_process(examples): + # Concatenate "context" and "text" for each example in the batch + concatenated_texts = [ + context + " " + text + for context, text in zip(examples["context"], examples["text"]) + ] + return {"text": concatenated_texts} + + # Apply the batch processing function to the dataset + dataset = dataset.map(batch_process, batched=True) + + logger.info("Merged context documents with the original text.") + + try: + dataset = dataset.map( + lambda x: preprocess_and_chunk_batch_for_hf_dataset( + x, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + chunk_text=args.chunk_text, + ), + batched=True, + remove_columns=["text"], + ) + except ValueError: + if not tokenizer.pad_token: + warnings.warn( + "The tokenizer you are using does not have a pad token set. 
Setting to 'tokenizer.eos_token'" + ) + tokenizer.pad_token = tokenizer.eos_token + dataset = dataset.map( + lambda x: preprocess_and_chunk_batch_for_hf_dataset( + x, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + chunk_text=args.chunk_text, + ), + batched=True, + remove_columns=["text"], + ) - if isinstance(data, str): + try: # This is not necessarily a train dataset. The datasets library insists on calling it train. - return dataset["train"] + dataset = dataset["train"] + except: + pass + + if "token_type_ids" in dataset.features: + dataset = dataset.with_format( + type="pt", columns=["input_ids", "token_type_ids", "attention_mask"] + ) else: - return dataset + dataset = dataset.with_format( + type="pt", columns=["input_ids", "attention_mask"] + ) + + return dataset class SimpleDataset(Dataset): diff --git a/simpletransformers/ner/ner_model.py b/simpletransformers/ner/ner_model.py index 19fdd6dc..7e3a02f9 100755 --- a/simpletransformers/ner/ner_model.py +++ b/simpletransformers/ner/ner_model.py @@ -1036,8 +1036,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: diff --git a/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py b/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py index 216414c9..40feae54 100644 --- a/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py +++ b/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py @@ -879,8 +879,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if clustered_training: diff --git a/simpletransformers/question_answering/question_answering_model.py b/simpletransformers/question_answering/question_answering_model.py index 3c3ee72c..7aa363fa 100755 --- a/simpletransformers/question_answering/question_answering_model.py +++ b/simpletransformers/question_answering/question_answering_model.py @@ -938,8 +938,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: diff --git a/simpletransformers/retrieval/retrieval_model.py b/simpletransformers/retrieval/retrieval_model.py index c75d5581..6c653b53 100644 --- a/simpletransformers/retrieval/retrieval_model.py +++ b/simpletransformers/retrieval/retrieval_model.py @@ -316,7 +316,7 @@ def __init__( query_encoder_name, **self.args.query_config ) # if self.args.query_config.get("projection_dim") is not None: - # query_encoder._keys_to_ignore_on_load_missing.append("encode_proj") + # query_encoder._keys_to_ignore_on_load_missing.append("encode_proj") self.query_encoder = query_encoder.from_pretrained( query_encoder_name, config=self.query_config ) @@ -1608,7 +1608,11 @@ def eval_model( else: custom_mrr = False - evaluator = pytrec_eval.RelevanceEvaluator(qrels_dict, pytrec_eval_metrics, relevance_level=self.args.relevance_level) + evaluator = pytrec_eval.RelevanceEvaluator( + qrels_dict, + pytrec_eval_metrics, + relevance_level=self.args.relevance_level, + ) try: results = evaluator.evaluate(run_dict) @@ -2208,7 +2212,11 @@ def predict( ) else: doc_ids, 
doc_vectors, doc_dicts = retrieval_outputs - passages = [d["passages"] for d in doc_dicts] + + try: + passages = [d["passages"] for d in doc_dicts] + except KeyError: + passages = [d["passage_text"] for d in doc_dicts] if self.args.unified_rr: rerank_similarity = compute_rerank_similarity( @@ -2416,7 +2424,10 @@ def retrieve_docs_from_query_embeddings( passages_only=True, ) - passages.extend([d["passages"] for d in doc_dicts_batch]) + try: + passages.extend([d["passages"] for d in doc_dicts_batch]) + except KeyError: + passages.extend([d["passage_text"] for d in doc_dicts_batch]) return passages elif doc_ids_only: diff --git a/simpletransformers/retrieval/retrieval_tools.py b/simpletransformers/retrieval/retrieval_tools.py index 8a58714a..b223ca5a 100644 --- a/simpletransformers/retrieval/retrieval_tools.py +++ b/simpletransformers/retrieval/retrieval_tools.py @@ -184,8 +184,6 @@ def generate_flipped_latex_row( return row - - def generate_latex_table( results, all_metrics, @@ -387,8 +385,6 @@ def generate_flipped_latex_table( return table - - def analyze_experiment( experiment_dir, model_name_map=None, @@ -516,10 +512,10 @@ def calculate_significance(run_a, run_b): return pvalue -def convert_trec_queries_to_beir_format(trec_queries_path, beir_queries_path=None, save=True): - trec_queries_df = pd.read_csv( - trec_queries_path, sep="\t", names=["_id", "text"] - ) +def convert_trec_queries_to_beir_format( + trec_queries_path, beir_queries_path=None, save=True +): + trec_queries_df = pd.read_csv(trec_queries_path, sep="\t", names=["_id", "text"]) trec_queries_df["_id"] = trec_queries_df["_id"].astype(str) trec_queries_df["text"] = trec_queries_df["text"].astype(str) if beir_queries_path is None: @@ -532,6 +528,7 @@ def convert_trec_queries_to_beir_format(trec_queries_path, beir_queries_path=Non return trec_queries_df + def convert_trec_qrels_to_beir_format(trec_qrels_path, beir_qrels_path=None, save=True): trec_qrel_df = pd.read_csv( trec_qrels_path, sep=" ", names=["query-id", "Q0", "corpus-id", "score"] @@ -544,7 +541,9 @@ def convert_trec_qrels_to_beir_format(trec_qrels_path, beir_qrels_path=None, sav # Extension for trec_qrels_path can be .txt or .tsv beir_qrels_path = os.path.join( os.path.dirname(trec_qrels_path), - os.path.basename(trec_qrels_path).replace(".txt", ".tsv").replace(".tsv", ".tsv"), + os.path.basename(trec_qrels_path) + .replace(".txt", ".tsv") + .replace(".tsv", ".tsv"), ) beir_qrels_df = trec_qrel_df[["query-id", "corpus-id", "score"]] diff --git a/simpletransformers/retrieval/retrieval_utils.py b/simpletransformers/retrieval/retrieval_utils.py index 450e1163..c2ec8974 100644 --- a/simpletransformers/retrieval/retrieval_utils.py +++ b/simpletransformers/retrieval/retrieval_utils.py @@ -97,7 +97,9 @@ def load_hf_dataset( ) dataset = dataset.map( lambda example: { - "gold_passage": example["title"] + " " + example["gold_passage"] if example["title"] is not None else example["gold_passage"] + "gold_passage": example["title"] + " " + example["gold_passage"] + if example["title"] is not None + else example["gold_passage"] } ) @@ -935,7 +937,6 @@ def get_prediction_passage_dataset( "csv", data_files=prediction_passages, delimiter="\t", - column_names=["passages"], cache_dir=args.dataset_cache_dir, ) prediction_passages_dataset = prediction_passages_dataset["train"] @@ -946,7 +947,7 @@ def get_prediction_passage_dataset( ) prediction_passages_dataset = prediction_passages_dataset.map( lambda example: { - "gold_passage": example["title"] + " " + example["gold_passage"] + 
"passages": example["title"] + " " + example["gold_passage"] } # Should these be "passage" instead of "gold_passage"? ) elif isinstance(prediction_passages, list): @@ -962,7 +963,7 @@ def get_prediction_passage_dataset( ) prediction_passages_dataset = prediction_passages_dataset.map( lambda example: { - "gold_passage": example["title"] + " " + example["gold_passage"] + "passages": example["title"] + " " + example["gold_passage"] } ) @@ -1086,7 +1087,9 @@ def get_top_docs( def get_top_doc_ids( self, question_hidden_states, n_docs=5, reranking_query_outputs=None ): - scores, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs) + scores, ids = self.dataset.search_batch( + "embeddings", question_hidden_states, n_docs + ) docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids] doc_ids = [doc["passage_id"] for doc in docs] diff --git a/simpletransformers/seq2seq/seq2seq_model.py b/simpletransformers/seq2seq/seq2seq_model.py index 2daf9b8e..ad0f1e23 100644 --- a/simpletransformers/seq2seq/seq2seq_model.py +++ b/simpletransformers/seq2seq/seq2seq_model.py @@ -976,8 +976,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: diff --git a/simpletransformers/t5/t5_model.py b/simpletransformers/t5/t5_model.py index 8091a4f4..602c4f8a 100644 --- a/simpletransformers/t5/t5_model.py +++ b/simpletransformers/t5/t5_model.py @@ -754,8 +754,7 @@ def train( epoch_number += 1 output_dir_current = os.path.join( output_dir, - "checkpoint-{}-epoch-{}".format(global_step, epoch_number) -, + "checkpoint-{}-epoch-{}".format(global_step, epoch_number), ) if args.save_model_every_epoch or args.evaluate_during_training: @@ -1167,14 +1166,16 @@ def rerank(self, eval_data, qrels): if args.n_gpu > 1: model = torch.nn.DataParallel(model) - eval_iterator = tqdm( - eval_dataloader, desc="Evaluating", disable=args.silent - ) + eval_iterator = tqdm(eval_dataloader, desc="Evaluating", disable=args.silent) reranking_preds = [] reranking_scores = [] - true_token_idx = self.tokenizer("true", add_special_tokens=False)["input_ids"][0] - false_token_idx = self.tokenizer("false", add_special_tokens=False)["input_ids"][0] + true_token_idx = self.tokenizer("true", add_special_tokens=False)["input_ids"][ + 0 + ] + false_token_idx = self.tokenizer("false", add_special_tokens=False)[ + "input_ids" + ][0] for batch in eval_iterator: inputs = self._get_inputs_dict(batch) @@ -1186,7 +1187,7 @@ def rerank(self, eval_data, qrels): encoder_outputs=inputs["encoder_outputs"], max_new_tokens=self.args.max_length, output_scores=True, - return_dict_in_generate=True + return_dict_in_generate=True, ) else: outputs = self.model.generate( @@ -1194,10 +1195,12 @@ def rerank(self, eval_data, qrels): attention_mask=inputs["attention_mask"], max_new_tokens=self.args.max_length, output_scores=True, - return_dict_in_generate=True + return_dict_in_generate=True, ) - preds, scores = self._get_reranking_outputs(outputs, true_token_idx, false_token_idx) + preds, scores = self._get_reranking_outputs( + outputs, true_token_idx, false_token_idx + ) reranking_preds.extend(preds) reranking_scores.extend(scores) @@ -1216,12 +1219,16 @@ def rerank(self, eval_data, qrels): # Sort by score for query_id in run_dict: run_dict[query_id] = dict( - sorted(run_dict[query_id].items(), key=lambda item: item[1], reverse=True) 
+ sorted( + run_dict[query_id].items(), key=lambda item: item[1], reverse=True + ) ) os.makedirs(args.output_dir, exist_ok=True) - runfile_save_path = os.path.join(args.output_dir, f"{eval_data.split('/')[-1]}-runfile.json") + runfile_save_path = os.path.join( + args.output_dir, f"{eval_data.split('/')[-1]}-runfile.json" + ) with open(runfile_save_path, "w") as f: json.dump(run_dict, f) @@ -1334,7 +1341,14 @@ def _get_inputs_dict(self, batch): return inputs def load_and_cache_examples( - self, data, evaluate=False, no_cache=False, verbose=True, silent=False, tokenize_targets=True, reranking=False + self, + data, + evaluate=False, + no_cache=False, + verbose=True, + silent=False, + tokenize_targets=True, + reranking=False, ): """ Creates a T5Dataset from data. @@ -1354,7 +1368,13 @@ def load_and_cache_examples( mode = "dev" if evaluate else "train" if self.args.use_hf_datasets: - dataset = load_hf_dataset(data, tokenizer, self.args, tokenize_targets=tokenize_targets, reranking=reranking) + dataset = load_hf_dataset( + data, + tokenizer, + self.args, + tokenize_targets=tokenize_targets, + reranking=reranking, + ) return dataset elif args.dataset_class: CustomDataset = args.dataset_class diff --git a/simpletransformers/t5/t5_utils.py b/simpletransformers/t5/t5_utils.py index e290fcc1..5ab9730b 100644 --- a/simpletransformers/t5/t5_utils.py +++ b/simpletransformers/t5/t5_utils.py @@ -92,7 +92,9 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals "input_text": datasets.Value("string"), "target_text": datasets.Value("string"), } - ) if args.add_prefix else datasets.Features( + ) + if args.add_prefix + else datasets.Features( { "input_text": datasets.Value("string"), "target_text": datasets.Value("string"), @@ -105,7 +107,9 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals # tokenize_targets = not (evaluate and args.model_type == "eet5") dataset = dataset.map( - lambda x: preprocess_batch_for_hf_dataset(x, tokenizer=tokenizer, args=args, tokenize_targets=tokenize_targets), + lambda x: preprocess_batch_for_hf_dataset( + x, tokenizer=tokenizer, args=args, tokenize_targets=tokenize_targets + ), batched=True, ) @@ -118,11 +122,17 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals if tokenize_targets: dataset.set_format( type="pt", - columns=["input_ids", "attention_mask", "labels", "encoder_outputs"], + columns=[ + "input_ids", + "attention_mask", + "labels", + "encoder_outputs", + ], ) else: dataset.set_format( - type="pt", columns=["input_ids", "attention_mask", "encoder_outputs"] + type="pt", + columns=["input_ids", "attention_mask", "encoder_outputs"], ) else: if tokenize_targets: @@ -131,9 +141,7 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals columns=["input_ids", "attention_mask", "labels"], ) else: - dataset.set_format( - type="pt", columns=["input_ids", "attention_mask"] - ) + dataset.set_format(type="pt", columns=["input_ids", "attention_mask"]) else: dataset.set_format(type="pt", columns=["input_ids", "attention_mask", "labels"]) diff --git a/train.txt b/train.txt new file mode 100644 index 00000000..721c1e19 --- /dev/null +++ b/train.txt @@ -0,0 +1,100 @@ +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! 
+Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! +Hello world with Simple Transformers! 
+Hello world with Simple Transformers! +Hello world with Simple Transformers!
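
Example usage of the new LLM fine-tuning and generation APIs added above. This is a minimal sketch, not part of the patch itself: the training file "train.jsonl" is a placeholder path, the model name and hyperparameters are illustrative (taken from examples/llms/train.py), and the peft and bitsandbytes packages are assumed to be installed.

import logging

from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
    GenerationArgs,
)

logging.basicConfig(level=logging.INFO)

model_args = LanguageModelingArgs()
model_args.mlm = False  # causal LM objective rather than masked LM
model_args.use_hf_datasets = True  # required for the "jsonl" data_format path
model_args.data_format = "jsonl"  # one {"text": "..."} object per line
model_args.chunk_text = False  # keep one training example per line
model_args.max_seq_length = 512
model_args.num_train_epochs = 1
model_args.train_batch_size = 2
model_args.peft = True  # train a LoRA adapter instead of the full weights
model_args.nf4 = True  # load the base model 4-bit quantized via bitsandbytes
model_args.lora_config = {"r": 8}  # extra kwargs forwarded to peft.LoraConfig
model_args.output_dir = "outputs/llm-adapter"

# "causal" maps to (AutoConfig, AutoModelForCausalLM, AutoTokenizer)
model = LanguageModelingModel(
    "causal",
    "meta-llama/Llama-2-7b-hf",
    args=model_args,
)

# "train.jsonl" is a placeholder; see examples/llms/download_squad.ipynb
# for how the SQuAD-based jsonl files are built.
model.train_model("train.jsonl")

# Generation mirrors the commented-out block in examples/llms/train.py.
generation_args = GenerationArgs()
generation_args.max_length = None  # get_dict() drops None values
generation_args.max_new_tokens = 100

responses, _ = model.predict(
    ["Question: What is the capital of France? Answer:"],
    generation_args=generation_args,
)
print(responses[0])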