diff --git a/CHANGELOG.md b/CHANGELOG.md
index f360bd9b..5b7867f7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-## [0.65.0] - 2025-01-31
+## [0.70.0] - 2024-02-15
+
+### Added
+
+- Added support for using LLMs with `LanguageModelingModel` (with or without quantization)
+- Added support for adapter tuning LLMs
+
+## [0.65.0] - 2024-01-31
- Lots of QOL improvements
- Added support for evaluating retrieval models with `pytrec_eval`
diff --git a/README.md b/README.md
index 867f8fb5..90184be9 100755
--- a/README.md
+++ b/README.md
@@ -9,17 +9,17 @@ This library is based on the [Transformers](https://github.com/huggingface/trans
**Supported Tasks:**
+- Information Retrieval (Dense Retrieval)
+- (Large) Language Models (Training, Fine-tuning, and Generation)
+- Encoder Model Training and Fine-tuning
- Sequence Classification
- Token Classification (NER)
- Question Answering
-- Language Model Fine-Tuning
-- Language Model Training
- Language Generation
- T5 Model
- Seq2Seq Tasks
- Multi-Modal Classification
- Conversational AI.
-- Text Representation Generation.
# Table of contents
diff --git a/examples/llms/download_squad.ipynb b/examples/llms/download_squad.ipynb
new file mode 100644
index 00000000..ed9d3077
--- /dev/null
+++ b/examples/llms/download_squad.ipynb
@@ -0,0 +1,375 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bfc6efd35877482fab97a58798a178be",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+        "Downloading readme: 0%| | 0.00/7.83k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b457b4b3b2eb4fa1a1e12b17731baea6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+        "Downloading data files: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3f38a9d9ea4f4d658c68d790e2137930",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+        "Downloading data: 0%| | 0.00/14.5M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9942411e672c41038f5886d0bec29322",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+        "Downloading data: 0%| | 0.00/1.82M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cbaf88354da04305bd3614875e301e23",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+        "Extracting data files: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c6c6a0c016b348d89495411c1d880f56",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+        "Generating train split: 0%| | 0/87599 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d89f137ce70440bb89364597d18e52e4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+        "Generating validation split: 0%| | 0/10570 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['id', 'title', 'context', 'question', 'answers'],\n",
+ " num_rows: 87599\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['id', 'title', 'context', 'question', 'answers'],\n",
+ " num_rows: 10570\n",
+ " })\n",
+ "})\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = datasets.load_dataset(\"squad\")\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data = dataset[\"train\"]\n",
+ "eval_data = dataset[\"validation\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_texts = [f\"Question: {q} Answer: {a['text'][0]}\" for q, a in zip(train_data[\"question\"], train_data[\"answers\"])]\n",
+    "eval_texts = [f\"Question: {q} Answer: \" for q in eval_data[\"question\"]]\n",
+ "eval_answers = [a[\"text\"][0] for a in eval_data[\"answers\"]]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? Answer: Saint Bernadette Soubirous'"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_texts[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+       "'Question: Which NFL team represented the AFC at Super Bowl 50? Answer: '"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "eval_texts[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " text\n",
+ "0 Question: To whom did the Virgin Mary allegedl...\n",
+ "1 Question: What is in front of the Notre Dame M...\n",
+ "2 Question: The Basilica of the Sacred heart at ...\n",
+ "3 Question: What is the Grotto at Notre Dame? An...\n",
+ "4 Question: What sits on top of the Main Buildin...\n",
+ "... ...\n",
+ "87594 Question: In what US state did Kathmandu first...\n",
+ "87595 Question: What was Yangon previously known as?...\n",
+ "87596 Question: With what Belorussian city does Kath...\n",
+ "87597 Question: In what year did Kathmandu create it...\n",
+ "87598 Question: What is KMC an initialism of? Answer...\n",
+ "\n",
+ "[87599 rows x 1 columns]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_df = pd.DataFrame(train_texts, columns=[\"text\"])\n",
+ "eval_df = pd.DataFrame({\"text\": eval_texts, \"answer\": eval_answers})\n",
+ "\n",
+ "train_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_df.to_json(\"../data/squad-train.jsonl\", orient=\"records\", lines=True)\n",
+ "eval_df.to_json(\"../data/squad-eval.jsonl\", orient=\"records\", lines=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_df[\"rag_query\"] = train_data[\"question\"]\n",
+ "eval_df[\"rag_query\"] = eval_data[\"question\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_df.to_json(\"../data/squad-train.jsonl\", orient=\"records\", lines=True)\n",
+ "eval_df.to_json(\"../data/squad-eval.jsonl\", orient=\"records\", lines=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "st",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/llms/train.py b/examples/llms/train.py
new file mode 100644
index 00000000..c8eb2bf2
--- /dev/null
+++ b/examples/llms/train.py
@@ -0,0 +1,98 @@
+import logging
+import pandas as pd
+
+from simpletransformers.language_modeling import LanguageModelingModel, LanguageModelingArgs, GenerationArgs
+from simpletransformers.retrieval import (
+ RetrievalModel,
+ RetrievalArgs,
+)
+
+
+logging.basicConfig(level=logging.INFO)
+transformers_logger = logging.getLogger("transformers")
+transformers_logger.setLevel(logging.WARNING)
+transformers_modules_logger = logging.getLogger("transformers_modules")
+transformers_modules_logger.setLevel(logging.ERROR)
+
+
+rag = True
+
+if rag:
+ from rag_setup import model as rag_model
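+    # rag_setup is not part of this diff; it is assumed to expose a RetrievalModel,
+    # preloaded with a knowledge index, as `model`.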
+
+
+train_file = "../data/squad-train.jsonl"
+
+
+model_args = LanguageModelingArgs()
+model_args.reprocess_input_data = True
+model_args.overwrite_output_dir = True
+model_args.num_train_epochs = 1
+model_args.save_eval_checkpoints = False
+model_args.save_model_every_epoch = False
+model_args.train_batch_size = 2
+model_args.eval_batch_size = 4
+model_args.gradient_accumulation_steps = 1
+model_args.manual_seed = 4
+model_args.fp16 = True
+model_args.dataset_type = "simple"
+model_args.logging_steps = 100
+model_args.evaluate_during_training = False
+model_args.mlm = False
+model_args.use_multiprocessing = False
+model_args.use_hf_datasets = True
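+# Adapter tuning / quantization flags: `peft` enables LoRA adapter training,
+# `nf4` loads the base model in 4-bit NF4 via bitsandbytes, and `qlora`
+# switches to LoftQ-initialized LoRA (qlora and nf4 cannot be combined).
+# `lora_config` and `loftq_config` are passed through to peft's LoraConfig/LoftQConfig.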
+model_args.peft = True
+model_args.qlora = False
+model_args.nf4 = True
+model_args.loftq_bits = 4
+model_args.lora_config = {"r": 8}
+model_args.data_format = "jsonl"
+model_args.trust_remote_code = True
+model_args.save_steps = 1000
+model_args.optimizer = "Adam8bit"
+model_args.chunk_text = False
+model_args.max_seq_length = 500
+
+if not rag:
+ model_args.wandb_project = "llama-adapter-tuning-squad"
+ model_args.wandb_kwargs = {"name": "squad-llama-2-7b-vanilla"}
+ model_args.output_dir = "squad-llama-2-7b"
+
+if rag:
+ model_args.rag = rag
+ model_args.wandb_project = "llama-adapter-tuning-squad"
+ model_args.wandb_kwargs = {"name": "squad-llama-2-7b-rag"}
+ model_args.output_dir = "squad-llama-2-7b-rag"
+
+
+model = LanguageModelingModel(
+ "causal",
+ # "outputs",
+ # "stabilityai/stablelm-zephyr-3b",
+ "meta-llama/Llama-2-7b-hf",
+ args=model_args,
+ retrieval_model=rag_model if rag else None,
+)
+
+model.train_model(
+ train_file,
+)
+
+generation_args = GenerationArgs()
+generation_args.max_length = None
+generation_args.max_new_tokens = 100
+
+# Evaluation prompts generated by examples/llms/download_squad.ipynb
+eval_df = pd.read_json("../data/squad-eval.jsonl", lines=True)
+
+to_predict = eval_df["text"].tolist()
+
+responses, _ = model.predict(
+    to_predict,
+    rag_queries=eval_df["rag_query"].tolist() if rag else None,
+    generation_args=generation_args,
+)
+
+print(responses[:5])
+
+eval_df["generated_text"] = responses
+
+eval_df.to_json("../data/squad-eval-output.jsonl", orient="records", lines=True)
diff --git a/examples/retrieval/download_msmarco.py b/examples/retrieval/download_msmarco.py
new file mode 100644
index 00000000..7976fcda
--- /dev/null
+++ b/examples/retrieval/download_msmarco.py
@@ -0,0 +1,22 @@
+import os
+from datasets import load_dataset
+
+
+os.makedirs("data/msmarco", exist_ok=True)
+
+print("=== Downloading MSMARCO ===")
+print("Downloading MSMARCO training triples...")
+dataset = load_dataset("thilina/negative-sampling")["train"]
+
+print("Dataset loaded. Sample:")
+print(dataset[0])
+
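+# BeIR qrels (query relevance judgments) for the MS MARCO dev split, used for evaluation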
+qrels = load_dataset("BeIR/msmarco-qrels")["validation"]
+
+print("Saving dataset to disk...")
+# Save the dataset to disk
+dataset.to_csv("data/msmarco/msmarco-train.tsv", sep="\t", index=False)
+qrels.to_csv("data/msmarco/devs.tsv", sep="\t", index=False)
+
+print("Done.")
+print("=== MSMARCO download complete ===")
diff --git a/examples/retrieval/train_dpr_base.py b/examples/retrieval/train_dpr_base.py
new file mode 100644
index 00000000..c6562704
--- /dev/null
+++ b/examples/retrieval/train_dpr_base.py
@@ -0,0 +1,64 @@
+import logging
+
+import pandas as pd
+from simpletransformers.retrieval import RetrievalModel, RetrievalArgs
+
+# Configuring logging
+logging.basicConfig(level=logging.INFO)
+transformers_logger = logging.getLogger("transformers")
+transformers_logger.setLevel(logging.WARNING)
+
+# Specifying the path to the training data
+train_data_path = "data/msmarco/msmarco-train.tsv"
+
+# Loading the training data
+if train_data_path.endswith(".tsv"):
+ train_data = pd.read_csv(train_data_path, sep="\t")
+else:
+ train_data = train_data_path
+
+# Configuring the model arguments
+model_args = RetrievalArgs()
+model_args.reprocess_input_data = True
+model_args.overwrite_output_dir = True
+model_args.use_cached_eval_features = False
+model_args.include_title = "msmarco" not in train_data_path
+model_args.max_seq_length = 256
+model_args.num_train_epochs = 40
+model_args.train_batch_size = 16
+model_args.use_hf_datasets = True
+model_args.learning_rate = 1e-6
+model_args.warmup_steps = 5000
+model_args.save_steps = 300000
+model_args.evaluate_during_training = True
+model_args.evaluate_during_training_steps = False
+model_args.save_model_every_epoch = False
+model_args.wandb_project = "Retrieval training example"
+model_args.hard_negatives = False
+model_args.n_gpu = 1
+model_args.data_format = "beir"
+model_args.output_dir = "trained_models/pretrained/DPR-base-msmarco"
+model_args.wandb_kwargs = {"name": "DPR-base-msmarco"}
+
+# Defining the model type and names
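+# With model_type "custom", the context and query encoders are loaded separately
+# from context_name and question_name below (model_name can stay None).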
+model_type = "custom"
+model_name = None
+context_name = "bert-base-multilingual-cased"
+question_name = "bert-base-multilingual-cased"
+
+# Main execution
+if __name__ == "__main__":
+ # Creating the model
+ model = RetrievalModel(
+ model_type,
+ model_name,
+ context_name,
+ question_name,
+ args=model_args,
+ )
+
+ # Training the model
+ model.train_model(
+ train_data,
+ eval_set="dev",
+ )
diff --git a/setup.py b/setup.py
index 7d924cc3..657771f1 100755
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
setup(
name="simpletransformers",
- version="0.65.1",
+ version="0.70.0",
author="Thilina Rajapakse",
author_email="chaturangarajapakshe@gmail.com",
description="An easy-to-use wrapper library for the Transformers library.",
diff --git a/simpletransformers/classification/classification_model.py b/simpletransformers/classification/classification_model.py
index 66b01238..6ea168b4 100755
--- a/simpletransformers/classification/classification_model.py
+++ b/simpletransformers/classification/classification_model.py
@@ -22,12 +22,14 @@
from scipy.special import softmax
from sklearn.metrics import (
confusion_matrix,
+ f1_score,
label_ranking_average_precision_score,
matthews_corrcoef,
mean_squared_error,
roc_curve,
auc,
average_precision_score,
+ accuracy_score,
)
from torch.utils.tensorboard import SummaryWriter
from torch.nn import CrossEntropyLoss
@@ -1155,8 +1157,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
@@ -1924,6 +1925,8 @@ def compute_metrics(
return {**extra_metrics}, wrong
mcc = matthews_corrcoef(labels, preds)
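+        # Binary F1 for two-label tasks; macro-averaged otherwise, since sklearn's
+        # default "binary" average raises an error on multiclass inputs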
+        accuracy = accuracy_score(labels, preds)
+        f1 = f1_score(
+            labels, preds, average="binary" if self.model.num_labels == 2 else "macro"
+        )
if self.model.num_labels == 2:
tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
if self.args.sliding_window:
@@ -1943,6 +1946,8 @@ def compute_metrics(
{
**{
"mcc": mcc,
+ "accuracy": accuracy,
+ "f1_score": f1,
"tp": tp,
"tn": tn,
"fp": fp,
diff --git a/simpletransformers/classification/multi_modal_classification_model.py b/simpletransformers/classification/multi_modal_classification_model.py
index b59acca7..9ec83919 100644
--- a/simpletransformers/classification/multi_modal_classification_model.py
+++ b/simpletransformers/classification/multi_modal_classification_model.py
@@ -814,8 +814,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
diff --git a/simpletransformers/config/model_args.py b/simpletransformers/config/model_args.py
index 2ab9da85..7e55cc57 100644
--- a/simpletransformers/config/model_args.py
+++ b/simpletransformers/config/model_args.py
@@ -93,6 +93,7 @@ class ModelArgs:
tokenizer_type: str = None
train_batch_size: int = 8
train_custom_parameters_only: bool = False
+ trust_remote_code: bool = False
use_cached_eval_features: bool = False
use_early_stopping: bool = False
use_hf_datasets: bool = False
@@ -246,6 +247,36 @@ class T5Args(ModelArgs):
use_multiprocessed_decoding: bool = True
+@dataclass
+class GenerationArgs:
+ """
+ Args for language generation.
+ """
+
+ max_length: int = 20
+ max_new_tokens: int = None
+ min_length: int = 0
+ min_new_tokens: int = None
+ early_stopping: bool = False
+ max_time: float = None
+
+ do_sample: bool = False
+ num_beams: int = 1
+ num_beam_groups: int = 1
+ penalty_alpha: float = None
+ use_cache: bool = True
+
+ temperature: float = 1.0
+ top_k: int = 50
+ top_p: float = 1.0
+ repetition_penalty: float = 1.0
+ diversity_penalty: float = 0.0
+
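+    # Only explicitly set (non-None) values are forwarded to generate(); the
+    # model's own generation defaults apply for everything else.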
+ def get_dict(self):
+ d = asdict(self)
+ return {k: v for k, v in d.items() if v is not None}
+
+
@dataclass
class LanguageModelingArgs(ModelArgs):
"""
@@ -254,6 +285,7 @@ class LanguageModelingArgs(ModelArgs):
model_class: str = "LanguageModelingModel"
block_size: int = -1
+ chunk_text: bool = True
config_name: str = None
dataset_class: Dataset = None
dataset_type: str = "None"
@@ -276,6 +308,14 @@ class LanguageModelingArgs(ModelArgs):
special_tokens_list: list = field(default_factory=list)
strip_accents: bool = True
local_rank: int = -1
+ loftq_bits: int = 4
+ loftq_config: dict = field(default_factory=dict)
+ lora_config: dict = field(default_factory=dict)
+ peft: bool = False
+ qlora: bool = False
+ rag: bool = False
+ rag_replace_method: str = "prepend"
+ nf4: bool = False
use_autoencoder: bool = False
stream_hf_datasets: bool = False
diff --git a/simpletransformers/conv_ai/conv_ai_model.py b/simpletransformers/conv_ai/conv_ai_model.py
index b8162f9f..f7a33da7 100644
--- a/simpletransformers/conv_ai/conv_ai_model.py
+++ b/simpletransformers/conv_ai/conv_ai_model.py
@@ -723,8 +723,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
diff --git a/simpletransformers/language_modeling/__init__.py b/simpletransformers/language_modeling/__init__.py
index bde1ccf1..264f6e5e 100755
--- a/simpletransformers/language_modeling/__init__.py
+++ b/simpletransformers/language_modeling/__init__.py
@@ -1,4 +1,4 @@
-from simpletransformers.config.model_args import LanguageModelingArgs
+from simpletransformers.config.model_args import LanguageModelingArgs, GenerationArgs
from simpletransformers.language_modeling.language_modeling_model import (
LanguageModelingModel,
)
diff --git a/simpletransformers/language_modeling/language_modeling_model.py b/simpletransformers/language_modeling/language_modeling_model.py
index c9cdb307..31dea2f4 100755
--- a/simpletransformers/language_modeling/language_modeling_model.py
+++ b/simpletransformers/language_modeling/language_modeling_model.py
@@ -49,6 +49,7 @@
from simpletransformers.custom_models.models import RobertaWithAutoEncoderForMaskedLM
from transformers import DummyObject, requires_backends
+from torch.utils.data import DataLoader, TensorDataset
class NystromformerTokenizer(metaclass=DummyObject):
@@ -63,6 +64,7 @@ def __init__(self, *args, **kwargs):
AutoConfig,
AutoModelWithLMHead,
AutoTokenizer,
+ AutoModelForCausalLM,
BertConfig,
BertForMaskedLM,
BertTokenizer,
@@ -102,6 +104,7 @@ def __init__(self, *args, **kwargs):
XLMRobertaConfig,
XLMRobertaForMaskedLM,
XLMRobertaTokenizer,
+ GenerationConfig,
)
from transformers.data.datasets.language_modeling import (
LineByLineTextDataset,
@@ -109,7 +112,7 @@ def __init__(self, *args, **kwargs):
)
from simpletransformers.config.global_args import global_args
-from simpletransformers.config.model_args import LanguageModelingArgs
+from simpletransformers.config.model_args import LanguageModelingArgs, GenerationArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.custom_models.models import ElectraForLanguageModelingModel
from simpletransformers.language_modeling.language_modeling_utils import (
@@ -132,6 +135,7 @@ def __init__(self, *args, **kwargs):
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
"bigbird": (BigBirdConfig, BigBirdForMaskedLM, BigBirdTokenizer),
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
+ "causal": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
"electra": (ElectraConfig, ElectraForLanguageModelingModel, ElectraTokenizer),
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
@@ -159,7 +163,9 @@ def __init__(
train_files=None,
args=None,
use_cuda=True,
- autoencoder_model=None,
+ retrieval_model=None,
+ adapter_name=None,
+ # autoencoder_model=None,
cuda_device=-1,
**kwargs,
):
@@ -173,6 +179,8 @@ def __init__(
discriminator_name (optional): A pretrained model name or path to a directory containing an ELECTRA discriminator model.
args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
train_files (optional): List of files to be used when training the tokenizer.
+            retrieval_model (optional): A pretrained model name or path to a directory containing a retrieval model. This should be preloaded with a knowledge index.
+            adapter_name (optional): Path to a trained PEFT adapter to load on top of the base model.
use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
cuda_device (optional): Specific GPU that should be used. Will use the first available GPU by default.
**kwargs (optional): For providing proxies, force_download, resume_download, cache_dir and other options specific to the 'from_pretrained' implementation where this will be supplied.
@@ -258,11 +266,17 @@ def __init__(
if self.args.config_name:
self.config = config_class.from_pretrained(
- self.args.config_name, cache_dir=self.args.cache_dir
+ self.args.config_name,
+ cache_dir=self.args.cache_dir,
+ trust_remote_code=self.args.trust_remote_code,
+ **kwargs,
)
elif self.args.model_name and self.args.model_name != "electra":
self.config = config_class.from_pretrained(
- model_name, cache_dir=self.args.cache_dir, **kwargs
+ model_name,
+ cache_dir=self.args.cache_dir,
+ trust_remote_code=self.args.trust_remote_code,
+ **kwargs,
)
else:
self.config = config_class(**self.args.config, **kwargs)
@@ -359,12 +373,65 @@ def __init__(
)
)
else:
- self.model = model_class.from_pretrained(
- model_name,
- config=self.config,
- cache_dir=self.args.cache_dir,
- **kwargs,
- )
+ if self.args.nf4:
+ from transformers import BitsAndBytesConfig
+
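+                # Load the base model in 4-bit NF4 with double quantization and bfloat16 compute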
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+                self.model = model_class.from_pretrained(
+                    model_name,
+                    quantization_config=nf4_config,
+                    cache_dir=self.args.cache_dir,
+                    trust_remote_code=self.args.trust_remote_code,
+                )
+ else:
+ self.model = model_class.from_pretrained(
+ model_name,
+ config=self.config,
+ cache_dir=self.args.cache_dir,
+ **kwargs,
+ )
+
+ if self.args.peft:
+ from peft import LoraConfig, get_peft_model, LoftQConfig
+ from peft.peft_model import PeftModel
+
+ if self.args.qlora:
+ if self.args.nf4:
+ raise ValueError(
+                            "qlora (LoftQ-initialized LoRA) and nf4 quantization cannot be used together"
+ )
+ loftq_config = LoftQConfig(
+ loftq_bits=self.args.loftq_bits, **self.args.loftq_config
+ )
+ self.lora_config = LoraConfig(
+ init_lora_weights="loftq",
+ target_modules="all-linear",
+ loftq_config=loftq_config,
+ **self.args.lora_config,
+ )
+ self.args.fp16 = False
+ else:
+                self.lora_config = LoraConfig(
+                    use_rslora=True,
+                    target_modules="all-linear",
+                    **self.args.lora_config,
+                )
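+            # Checkpointing trades compute for memory; enabling input grads lets
+            # adapter gradients flow through the frozen (possibly quantized) base model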
+ self.model.gradient_checkpointing_enable()
+ self.model.enable_input_require_grads()
+ if adapter_name is not None:
+ self.model = PeftModel.from_pretrained(
+ self.model,
+ model_id=adapter_name,
+ adapter_name=adapter_name,
+ )
+                self.adapter_name = adapter_name
+            else:
+                self.adapter_name = None
+                self.model = get_peft_model(self.model, self.lora_config)
+ self.model.print_trainable_parameters()
+
else:
logger.info(" Training language model from scratch")
if self.args.model_type == "electra":
@@ -432,6 +499,15 @@ def __init__(
)
self.args.wandb_project = None
+ if self.args.rag:
+ if retrieval_model:
+ self.retrieval_model = retrieval_model
+ else:
+ raise ValueError(
+ "RAG is enabled but no retrieval model is specified."
+ " Pass a retrieval model when instantiating the LanguageModelingModel to use RAG."
+                )
+        else:
+            self.retrieval_model = None
+
def train_model(
self,
train_file,
@@ -608,6 +684,15 @@ def collate(examples: List[torch.Tensor]):
relative_step=args.adafactor_relative_step,
warmup_init=args.adafactor_warmup_init,
)
+ elif args.optimizer == "Adam8bit":
+ from bitsandbytes.optim import Adam8bit
+
+ optimizer = Adam8bit(
+ optimizer_grouped_parameters,
+ lr=args.learning_rate,
+ eps=args.adam_epsilon,
+ betas=args.adam_betas,
+ )
else:
raise ValueError(
@@ -769,7 +854,6 @@ def collate(examples: List[torch.Tensor]):
# TODO: Move this to _get_inputs_dict and keep the attention masks
if self.args.use_hf_datasets:
if self.args.stream_hf_datasets:
- # BUG: TOKENIZATION IS BUGGED FOR HF DATASETS
batch["input_ids"] = torch.stack(batch["input_ids"])
batch["attention_mask"] = torch.stack(batch["attention_mask"])
if self.args.model_type in [
@@ -806,16 +890,22 @@ def collate(examples: List[torch.Tensor]):
)
token_type_ids = (
batch["token_type_ids"].to(self.device)
- if self.args.use_hf_datasets
+ if self.args.use_hf_datasets and "token_type_ids" in batch
else None
)
labels = labels.to(self.device)
- inputs_dict = {
- "input_ids": inputs,
- "attention_mask": attention_mask,
- "token_type_ids": token_type_ids,
- }
+ if token_type_ids is None:
+ inputs_dict = {
+ "input_ids": inputs,
+ "attention_mask": attention_mask,
+ }
+ else:
+ inputs_dict = {
+ "input_ids": inputs,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ }
if args.fp16:
with amp.autocast():
@@ -1100,8 +1190,7 @@ def collate(examples: List[torch.Tensor]):
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
@@ -1249,7 +1338,13 @@ def collate(examples: List[torch.Tensor]):
)
def eval_model(
- self, eval_file, output_dir=None, verbose=True, silent=False, **kwargs
+ self,
+ eval_file,
+ output_dir=None,
+ evaluate_generated_text=False,
+ verbose=True,
+ silent=False,
+ **kwargs,
):
"""
Evaluates the model on eval_df. Saves results to args.output_dir
@@ -1261,26 +1356,30 @@ def eval_model(
self._move_model_to_device()
- eval_dataset = self.load_and_cache_examples(
- eval_file, evaluate=True, verbose=verbose, silent=silent
- )
- os.makedirs(output_dir, exist_ok=True)
+ if evaluate_generated_text:
+ raise NotImplementedError(
+ "evaluate_generated_text is not yet implemented for this model type."
+ )
+ else:
+ eval_dataset = self.load_and_cache_examples(
+ eval_file, evaluate=True, verbose=verbose, silent=silent
+ )
+ os.makedirs(output_dir, exist_ok=True)
- result = self.evaluate(
- eval_dataset, output_dir, verbose=verbose, silent=silent, **kwargs
- )
- self.results.update(result)
+ result = self.evaluate(
+ eval_dataset, output_dir, verbose=verbose, silent=silent, **kwargs
+ )
+ self.results.update(result)
- if verbose:
- logger.info(self.results)
+ if verbose:
+ logger.info(self.results)
- return result
+ return result
def evaluate(
self,
eval_dataset,
output_dir,
- multi_label=False,
prefix="",
verbose=True,
silent=False,
@@ -1331,18 +1430,39 @@ def collate(examples: List[torch.Tensor]):
eval_dataloader, disable=args.silent or silent, desc="Running Evaluation"
):
if self.args.use_hf_datasets:
- batch = batch["input_ids"]
+                    input_ids = batch["input_ids"]
+                else:
+                    input_ids = batch
inputs, labels = (
- mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
+                    mask_tokens(input_ids, tokenizer, args)
+ if args.mlm
+ else (input_ids, input_ids)
)
inputs = inputs.to(self.device)
labels = labels.to(self.device)
+
+                if self.args.use_hf_datasets and "token_type_ids" in batch:
+ inputs_dict = {
+ "input_ids": inputs,
+ "attention_mask": batch["attention_mask"].to(self.device)
+ if self.args.use_hf_datasets
+ else None,
+ "token_type_ids": batch["token_type_ids"].to(self.device)
+ if self.args.use_hf_datasets
+ else None,
+ }
+ else:
+ inputs_dict = {
+ "input_ids": inputs,
+ "attention_mask": batch["attention_mask"].to(self.device)
+ if self.args.use_hf_datasets
+ else None,
+ }
+
with torch.no_grad():
outputs = (
- model(inputs, labels=labels)
+ model(**inputs_dict, labels=labels)
if args.mlm
- else model(inputs, labels=labels)
+ else model(**inputs_dict, labels=labels)
)
if args.model_type == "electra":
g_loss = outputs[0]
@@ -1368,6 +1488,127 @@ def collate(examples: List[torch.Tensor]):
return results
+ def predict(
+ self,
+ to_predict,
+ generation_args=None,
+ rag_queries=None,
+ knowledge_dataset=None,
+ **kwargs,
+ ):
+ """
+ Performs text completions on a list of text. To be used with language models.
+
+ Args:
+
+ to_predict: A list of text to make predictions on.
+ generation_args: An instance of the `GenerationArgs` class containing the generation arguments for the model.
+ rag_queries (optional): A list of text to be used as queries for the RAG model. Only applicable if rag is enabled.
+ knowledge_dataset (optional): A list of text to be used as knowledge for the RAG model. Only applicable if the model is a RAG model.
+            **kwargs: Additional arguments to be passed to the model's `generate()` method during inference.
+
+ Returns:
+            responses: A list of generated sequences, one per input.
+            outputs: The raw output of `generate()` for each batch.
+ """
+ self._move_model_to_device()
+
+ if not generation_args:
+ generation_args = GenerationArgs()
+
+ if self.args.peft and self.adapter_name:
+ logger.info(
+                "Merging adapter with model for faster inference. Continuing training from this point may result in unexpected behavior."
+ )
+ self.model = self.model.merge_and_unload()
+
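+        # Decoder-only models need left padding so prompts sit flush against the generated tokens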
+ self.tokenizer.padding_side = "left"
+
+ if self.args.rag:
+            if not rag_queries:
+                warnings.warn(
+                    "No `rag_queries` provided. Using `to_predict` as `rag_queries`."
+                )
+                rag_queries = to_predict
+
+ context_docs = self.retrieval_model.predict(
+ rag_queries,
+ passages_only=True,
+ prediction_passages=knowledge_dataset,
+ )
+
+ to_predict = [
+ f"Context: {' '.join(context_doc)} {text}"
+ for context_doc, text in zip(context_docs, to_predict)
+ ]
+        # TODO: Support args.rag_replace_method="replace" (fill a placeholder in
+        # to_predict with the retrieved context) in addition to the current prepend behavior.
+
+ try:
+ inputs = self.tokenizer(
+ to_predict,
+ padding=True,
+ return_tensors="pt",
+ )
+ except ValueError:
+ if not self.tokenizer.pad_token:
+ warnings.warn(
+ "The tokenizer you are using does not have a pad token set. Setting to `eos_token`."
+ )
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ inputs = self.tokenizer(
+ to_predict,
+ padding=True,
+ return_tensors="pt",
+ )
+
+ input_ids_tensor = inputs["input_ids"]
+ attention_mask_tensor = inputs["attention_mask"]
+
+ # Create a TensorDataset
+ dataset = TensorDataset(input_ids_tensor, attention_mask_tensor)
+
+        # Create the dataloader; batching uses args.eval_batch_size
+ predict_dataloader = DataLoader(
+ dataset, batch_size=self.args.eval_batch_size, shuffle=False
+ )
+
+ # Put model in evaluation mode
+ self.model.eval()
+
+ # Predict
+ responses = []
+ outputs = []
+ for batch in tqdm(
+ predict_dataloader, desc="Generating outputs", disable=self.args.silent
+ ):
+ batch = tuple(t.to(self.device) for t in batch)
+ input_ids, attention_mask = batch
+
+ generation_output = self.model.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ return_dict_in_generate=True,
+ output_scores=True,
+ **generation_args.get_dict(),
+ **kwargs,
+ )
+
+ for i, s in enumerate(generation_output.sequences):
+ output = self.tokenizer.decode(
+ s[input_ids[i].shape[0] :], skip_special_tokens=True
+ )
+ responses.append(output)
+
+            outputs.append(generation_output)
+
+        return responses, outputs
+
def load_and_cache_examples(
self, file_path, evaluate=False, no_cache=False, verbose=True, silent=False
):
@@ -1389,7 +1630,9 @@ def load_and_cache_examples(
mode = "dev" if evaluate else "train"
if self.args.use_hf_datasets:
- dataset = load_hf_dataset(file_path, tokenizer, self.args)
+ dataset = load_hf_dataset(
+ file_path, tokenizer, self.args, retrieval_model=self.retrieval_model
+ )
return dataset
elif args.dataset_class:
CustomDataset = args.dataset_class
@@ -1629,7 +1872,8 @@ def _threshold(self, x, threshold):
return 0
def _move_model_to_device(self):
- self.model.to(self.device)
+ if not self.args.qlora and not self.args.nf4:
+ self.model.to(self.device)
def _create_training_progress_scores(self, **kwargs):
extra_metrics = {key: [] for key in kwargs}
diff --git a/simpletransformers/language_modeling/language_modeling_utils.py b/simpletransformers/language_modeling/language_modeling_utils.py
index 09264119..7ae9fb7b 100644
--- a/simpletransformers/language_modeling/language_modeling_utils.py
+++ b/simpletransformers/language_modeling/language_modeling_utils.py
@@ -3,6 +3,7 @@
import pickle
from multiprocessing import Pool
from typing import Tuple
+import warnings
import torch
from torch.utils.data import Dataset
@@ -74,13 +75,23 @@ def chunk_sequence(sequence, max_length):
def preprocess_and_chunk_batch_for_hf_dataset(
- dataset, tokenizer, max_seq_length, max_word_length=100
+ dataset, tokenizer, max_seq_length, chunk_text=True
):
- chunked_texts = []
- for text in dataset["text"]:
- chunks = chunk_sequence(text, max_seq_length)
- for chunk in chunks:
- chunked_texts.append(chunk)
+ if chunk_text:
+ chunked_texts = []
+ for text in dataset["text"]:
+ chunks = chunk_sequence(text, max_seq_length)
+ for chunk in chunks:
+ chunked_texts.append(chunk)
+
+ logger.info(
+ "Chunked %d examples into %d chunks with a maximum length of %d.",
+ len(dataset["text"]),
+ len(chunked_texts),
+ max_seq_length,
+ )
+ else:
+ chunked_texts = dataset["text"]
return tokenizer(
text=chunked_texts,
@@ -99,14 +110,15 @@ def preprocess_batch_for_hf_dataset(dataset, tokenizer, max_seq_length):
)
-def load_hf_dataset(data, tokenizer, args):
+def load_hf_dataset(data, tokenizer, args, retrieval_model=None):
if args.data_format == "text":
dataset = load_dataset(
+ "text",
data_files=data,
download_mode="force_redownload"
if args.reprocess_input_data
else "reuse_dataset_if_exists",
- streaming=True,
+        streaming=args.stream_hf_datasets,
)
elif args.data_format == "tsv":
dataset = load_dataset(
@@ -116,28 +128,94 @@ def load_hf_dataset(data, tokenizer, args):
download_mode="force_redownload"
if args.reprocess_input_data
else "reuse_dataset_if_exists",
- streaming=True,
+        streaming=args.stream_hf_datasets,
+ )
+ elif args.data_format == "json" or args.data_format == "jsonl":
+ dataset = load_dataset(
+ "json",
+ data_files=data,
+ download_mode="force_redownload"
+ if args.reprocess_input_data
+ else "reuse_dataset_if_exists",
+        streaming=args.stream_hf_datasets,
)
    else:
-        raise ValueError("args.data_format must be either 'text' or 'tsv'")
+        raise ValueError(
+            "args.data_format must be one of 'text', 'tsv', 'json', or 'jsonl'"
+        )
- dataset = dataset.map(
- lambda x: preprocess_and_chunk_batch_for_hf_dataset(
- x, tokenizer=tokenizer, max_seq_length=args.max_seq_length
- ),
- batched=True,
- remove_columns=["text"],
- )
-
- # dataset = dataset.with_format(
- # type="pt", columns=["input_ids", "token_type_ids", "attention_mask"]
- # )
+ if retrieval_model:
+ if retrieval_model.prediction_passages is None:
+ raise ValueError(
+ "The RetrievalModel must be initialized with prediction_passages to use it for RAG training."
+ )
+ dataset = dataset["train"]
+ logger.info("Retrieving context documents for RAG training.")
+ rag_queries = dataset["rag_query"]
+ context_docs = retrieval_model.predict(rag_queries, passages_only=True)
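+        # Move the encoders back to CPU to free GPU memory for LM training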
+ retrieval_model.context_encoder.to("cpu")
+ retrieval_model.query_encoder.to("cpu")
+ context_docs = [" ".join(docs) for docs in context_docs]
+
+ dataset = dataset.add_column("context", context_docs)
+
+ logger.info("Merging context documents with the original text.")
+
+ def batch_process(examples):
+ # Concatenate "context" and "text" for each example in the batch
+ concatenated_texts = [
+ context + " " + text
+ for context, text in zip(examples["context"], examples["text"])
+ ]
+ return {"text": concatenated_texts}
+
+ # Apply the batch processing function to the dataset
+ dataset = dataset.map(batch_process, batched=True)
+
+ logger.info("Merged context documents with the original text.")
+
+ try:
+ dataset = dataset.map(
+ lambda x: preprocess_and_chunk_batch_for_hf_dataset(
+ x,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length,
+ chunk_text=args.chunk_text,
+ ),
+ batched=True,
+ remove_columns=["text"],
+ )
+ except ValueError:
+ if not tokenizer.pad_token:
+ warnings.warn(
+ "The tokenizer you are using does not have a pad token set. Setting to 'tokenizer.eos_token'"
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+ dataset = dataset.map(
+ lambda x: preprocess_and_chunk_batch_for_hf_dataset(
+ x,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length,
+ chunk_text=args.chunk_text,
+ ),
+ batched=True,
+ remove_columns=["text"],
+ )
- if isinstance(data, str):
+ try:
# This is not necessarily a train dataset. The datasets library insists on calling it train.
- return dataset["train"]
+ dataset = dataset["train"]
+    except (KeyError, TypeError):
+ pass
+
+ if "token_type_ids" in dataset.features:
+ dataset = dataset.with_format(
+ type="pt", columns=["input_ids", "token_type_ids", "attention_mask"]
+ )
else:
- return dataset
+ dataset = dataset.with_format(
+ type="pt", columns=["input_ids", "attention_mask"]
+ )
+
+ return dataset
class SimpleDataset(Dataset):
diff --git a/simpletransformers/ner/ner_model.py b/simpletransformers/ner/ner_model.py
index 19fdd6dc..7e3a02f9 100755
--- a/simpletransformers/ner/ner_model.py
+++ b/simpletransformers/ner/ner_model.py
@@ -1036,8 +1036,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
diff --git a/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py b/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py
index 216414c9..40feae54 100644
--- a/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py
+++ b/simpletransformers/pretrain_retrieval/pretrain_retrieval_model.py
@@ -879,8 +879,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if clustered_training:
diff --git a/simpletransformers/question_answering/question_answering_model.py b/simpletransformers/question_answering/question_answering_model.py
index 3c3ee72c..7aa363fa 100755
--- a/simpletransformers/question_answering/question_answering_model.py
+++ b/simpletransformers/question_answering/question_answering_model.py
@@ -938,8 +938,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
diff --git a/simpletransformers/retrieval/retrieval_model.py b/simpletransformers/retrieval/retrieval_model.py
index c75d5581..6c653b53 100644
--- a/simpletransformers/retrieval/retrieval_model.py
+++ b/simpletransformers/retrieval/retrieval_model.py
@@ -316,7 +316,7 @@ def __init__(
query_encoder_name, **self.args.query_config
)
# if self.args.query_config.get("projection_dim") is not None:
- # query_encoder._keys_to_ignore_on_load_missing.append("encode_proj")
+ # query_encoder._keys_to_ignore_on_load_missing.append("encode_proj")
self.query_encoder = query_encoder.from_pretrained(
query_encoder_name, config=self.query_config
)
@@ -1608,7 +1608,11 @@ def eval_model(
else:
custom_mrr = False
- evaluator = pytrec_eval.RelevanceEvaluator(qrels_dict, pytrec_eval_metrics, relevance_level=self.args.relevance_level)
+ evaluator = pytrec_eval.RelevanceEvaluator(
+ qrels_dict,
+ pytrec_eval_metrics,
+ relevance_level=self.args.relevance_level,
+ )
try:
results = evaluator.evaluate(run_dict)
@@ -2208,7 +2212,11 @@ def predict(
)
else:
doc_ids, doc_vectors, doc_dicts = retrieval_outputs
- passages = [d["passages"] for d in doc_dicts]
+
+ try:
+ passages = [d["passages"] for d in doc_dicts]
+ except KeyError:
+ passages = [d["passage_text"] for d in doc_dicts]
if self.args.unified_rr:
rerank_similarity = compute_rerank_similarity(
@@ -2416,7 +2424,10 @@ def retrieve_docs_from_query_embeddings(
passages_only=True,
)
- passages.extend([d["passages"] for d in doc_dicts_batch])
+ try:
+ passages.extend([d["passages"] for d in doc_dicts_batch])
+ except KeyError:
+ passages.extend([d["passage_text"] for d in doc_dicts_batch])
return passages
elif doc_ids_only:
diff --git a/simpletransformers/retrieval/retrieval_tools.py b/simpletransformers/retrieval/retrieval_tools.py
index 8a58714a..b223ca5a 100644
--- a/simpletransformers/retrieval/retrieval_tools.py
+++ b/simpletransformers/retrieval/retrieval_tools.py
@@ -184,8 +184,6 @@ def generate_flipped_latex_row(
return row
-
-
def generate_latex_table(
results,
all_metrics,
@@ -387,8 +385,6 @@ def generate_flipped_latex_table(
return table
-
-
def analyze_experiment(
experiment_dir,
model_name_map=None,
@@ -516,10 +512,10 @@ def calculate_significance(run_a, run_b):
return pvalue
-def convert_trec_queries_to_beir_format(trec_queries_path, beir_queries_path=None, save=True):
- trec_queries_df = pd.read_csv(
- trec_queries_path, sep="\t", names=["_id", "text"]
- )
+def convert_trec_queries_to_beir_format(
+ trec_queries_path, beir_queries_path=None, save=True
+):
+ trec_queries_df = pd.read_csv(trec_queries_path, sep="\t", names=["_id", "text"])
trec_queries_df["_id"] = trec_queries_df["_id"].astype(str)
trec_queries_df["text"] = trec_queries_df["text"].astype(str)
if beir_queries_path is None:
@@ -532,6 +528,7 @@ def convert_trec_queries_to_beir_format(trec_queries_path, beir_queries_path=Non
return trec_queries_df
+
def convert_trec_qrels_to_beir_format(trec_qrels_path, beir_qrels_path=None, save=True):
trec_qrel_df = pd.read_csv(
trec_qrels_path, sep=" ", names=["query-id", "Q0", "corpus-id", "score"]
@@ -544,7 +541,9 @@ def convert_trec_qrels_to_beir_format(trec_qrels_path, beir_qrels_path=None, sav
# Extension for trec_qrels_path can be .txt or .tsv
beir_qrels_path = os.path.join(
os.path.dirname(trec_qrels_path),
- os.path.basename(trec_qrels_path).replace(".txt", ".tsv").replace(".tsv", ".tsv"),
+            os.path.basename(trec_qrels_path).replace(".txt", ".tsv"),
)
beir_qrels_df = trec_qrel_df[["query-id", "corpus-id", "score"]]
diff --git a/simpletransformers/retrieval/retrieval_utils.py b/simpletransformers/retrieval/retrieval_utils.py
index 450e1163..c2ec8974 100644
--- a/simpletransformers/retrieval/retrieval_utils.py
+++ b/simpletransformers/retrieval/retrieval_utils.py
@@ -97,7 +97,9 @@ def load_hf_dataset(
)
dataset = dataset.map(
lambda example: {
- "gold_passage": example["title"] + " " + example["gold_passage"] if example["title"] is not None else example["gold_passage"]
+ "gold_passage": example["title"] + " " + example["gold_passage"]
+ if example["title"] is not None
+ else example["gold_passage"]
}
)
@@ -935,7 +937,6 @@ def get_prediction_passage_dataset(
"csv",
data_files=prediction_passages,
delimiter="\t",
- column_names=["passages"],
cache_dir=args.dataset_cache_dir,
)
prediction_passages_dataset = prediction_passages_dataset["train"]
@@ -946,7 +947,7 @@ def get_prediction_passage_dataset(
)
prediction_passages_dataset = prediction_passages_dataset.map(
lambda example: {
- "gold_passage": example["title"] + " " + example["gold_passage"]
+ "passages": example["title"] + " " + example["gold_passage"]
-            } # Should these be "passage" instead of "gold_passage"?
+            }
)
elif isinstance(prediction_passages, list):
@@ -962,7 +963,7 @@ def get_prediction_passage_dataset(
)
prediction_passages_dataset = prediction_passages_dataset.map(
lambda example: {
- "gold_passage": example["title"] + " " + example["gold_passage"]
+ "passages": example["title"] + " " + example["gold_passage"]
}
)
@@ -1086,7 +1087,9 @@ def get_top_docs(
def get_top_doc_ids(
self, question_hidden_states, n_docs=5, reranking_query_outputs=None
):
- scores, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
+ scores, ids = self.dataset.search_batch(
+ "embeddings", question_hidden_states, n_docs
+ )
docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
doc_ids = [doc["passage_id"] for doc in docs]
diff --git a/simpletransformers/seq2seq/seq2seq_model.py b/simpletransformers/seq2seq/seq2seq_model.py
index 2daf9b8e..ad0f1e23 100644
--- a/simpletransformers/seq2seq/seq2seq_model.py
+++ b/simpletransformers/seq2seq/seq2seq_model.py
@@ -976,8 +976,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
diff --git a/simpletransformers/t5/t5_model.py b/simpletransformers/t5/t5_model.py
index 8091a4f4..602c4f8a 100644
--- a/simpletransformers/t5/t5_model.py
+++ b/simpletransformers/t5/t5_model.py
@@ -754,8 +754,7 @@ def train(
epoch_number += 1
output_dir_current = os.path.join(
output_dir,
- "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
-,
+ "checkpoint-{}-epoch-{}".format(global_step, epoch_number),
)
if args.save_model_every_epoch or args.evaluate_during_training:
@@ -1167,14 +1166,16 @@ def rerank(self, eval_data, qrels):
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
- eval_iterator = tqdm(
- eval_dataloader, desc="Evaluating", disable=args.silent
- )
+ eval_iterator = tqdm(eval_dataloader, desc="Evaluating", disable=args.silent)
reranking_preds = []
reranking_scores = []
- true_token_idx = self.tokenizer("true", add_special_tokens=False)["input_ids"][0]
- false_token_idx = self.tokenizer("false", add_special_tokens=False)["input_ids"][0]
+ true_token_idx = self.tokenizer("true", add_special_tokens=False)["input_ids"][
+ 0
+ ]
+ false_token_idx = self.tokenizer("false", add_special_tokens=False)[
+ "input_ids"
+ ][0]
for batch in eval_iterator:
inputs = self._get_inputs_dict(batch)
@@ -1186,7 +1187,7 @@ def rerank(self, eval_data, qrels):
encoder_outputs=inputs["encoder_outputs"],
max_new_tokens=self.args.max_length,
output_scores=True,
- return_dict_in_generate=True
+ return_dict_in_generate=True,
)
else:
outputs = self.model.generate(
@@ -1194,10 +1195,12 @@ def rerank(self, eval_data, qrels):
attention_mask=inputs["attention_mask"],
max_new_tokens=self.args.max_length,
output_scores=True,
- return_dict_in_generate=True
+ return_dict_in_generate=True,
)
- preds, scores = self._get_reranking_outputs(outputs, true_token_idx, false_token_idx)
+ preds, scores = self._get_reranking_outputs(
+ outputs, true_token_idx, false_token_idx
+ )
reranking_preds.extend(preds)
reranking_scores.extend(scores)
@@ -1216,12 +1219,16 @@ def rerank(self, eval_data, qrels):
# Sort by score
for query_id in run_dict:
run_dict[query_id] = dict(
- sorted(run_dict[query_id].items(), key=lambda item: item[1], reverse=True)
+ sorted(
+ run_dict[query_id].items(), key=lambda item: item[1], reverse=True
+ )
)
os.makedirs(args.output_dir, exist_ok=True)
- runfile_save_path = os.path.join(args.output_dir, f"{eval_data.split('/')[-1]}-runfile.json")
+ runfile_save_path = os.path.join(
+ args.output_dir, f"{eval_data.split('/')[-1]}-runfile.json"
+ )
with open(runfile_save_path, "w") as f:
json.dump(run_dict, f)
@@ -1334,7 +1341,14 @@ def _get_inputs_dict(self, batch):
return inputs
def load_and_cache_examples(
- self, data, evaluate=False, no_cache=False, verbose=True, silent=False, tokenize_targets=True, reranking=False
+ self,
+ data,
+ evaluate=False,
+ no_cache=False,
+ verbose=True,
+ silent=False,
+ tokenize_targets=True,
+ reranking=False,
):
"""
Creates a T5Dataset from data.
@@ -1354,7 +1368,13 @@ def load_and_cache_examples(
mode = "dev" if evaluate else "train"
if self.args.use_hf_datasets:
- dataset = load_hf_dataset(data, tokenizer, self.args, tokenize_targets=tokenize_targets, reranking=reranking)
+ dataset = load_hf_dataset(
+ data,
+ tokenizer,
+ self.args,
+ tokenize_targets=tokenize_targets,
+ reranking=reranking,
+ )
return dataset
elif args.dataset_class:
CustomDataset = args.dataset_class
diff --git a/simpletransformers/t5/t5_utils.py b/simpletransformers/t5/t5_utils.py
index e290fcc1..5ab9730b 100644
--- a/simpletransformers/t5/t5_utils.py
+++ b/simpletransformers/t5/t5_utils.py
@@ -92,7 +92,9 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals
"input_text": datasets.Value("string"),
"target_text": datasets.Value("string"),
}
- ) if args.add_prefix else datasets.Features(
+ )
+ if args.add_prefix
+ else datasets.Features(
{
"input_text": datasets.Value("string"),
"target_text": datasets.Value("string"),
@@ -105,7 +107,9 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals
# tokenize_targets = not (evaluate and args.model_type == "eet5")
dataset = dataset.map(
- lambda x: preprocess_batch_for_hf_dataset(x, tokenizer=tokenizer, args=args, tokenize_targets=tokenize_targets),
+ lambda x: preprocess_batch_for_hf_dataset(
+ x, tokenizer=tokenizer, args=args, tokenize_targets=tokenize_targets
+ ),
batched=True,
)
@@ -118,11 +122,17 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals
if tokenize_targets:
dataset.set_format(
type="pt",
- columns=["input_ids", "attention_mask", "labels", "encoder_outputs"],
+ columns=[
+ "input_ids",
+ "attention_mask",
+ "labels",
+ "encoder_outputs",
+ ],
)
else:
dataset.set_format(
- type="pt", columns=["input_ids", "attention_mask", "encoder_outputs"]
+ type="pt",
+ columns=["input_ids", "attention_mask", "encoder_outputs"],
)
else:
if tokenize_targets:
@@ -131,9 +141,7 @@ def load_hf_dataset(data, tokenizer, args, tokenize_targets=True, reranking=Fals
columns=["input_ids", "attention_mask", "labels"],
)
else:
- dataset.set_format(
- type="pt", columns=["input_ids", "attention_mask"]
- )
+ dataset.set_format(type="pt", columns=["input_ids", "attention_mask"])
else:
dataset.set_format(type="pt", columns=["input_ids", "attention_mask", "labels"])