diff --git a/elmo-sanity-check-v2-elmo_40k.ipynb b/elmo-sanity-check-v2-elmo_40k.ipynb
new file mode 100644
index 0000000..d2c4073
--- /dev/null
+++ b/elmo-sanity-check-v2-elmo_40k.ipynb
@@ -0,0 +1,1104 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "refined-creek",
+ "metadata": {},
+ "source": [
+ "## Prepare Data + Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "generous-timothy",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !cat examples/data/text_forward.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "valid-lingerie",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !ls -al ./outputs/en.1-percent.elmo-bert-causal-fixr2l"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "optional-motivation",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "tired-commodity",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n",
+ "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n",
+ "from transformers import BertConfig"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "outdoor-recognition",
+ "metadata": {},
+ "source": [
+ "#### Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "fifty-projector",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ELMOBertLMHeadModel.from_pretrained(\n",
+ " \"./outputs/en.100-percent.elmo-bert-causal.40k\"\n",
+ ") # use pre-trained model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "former-neighbor",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model in eval mode for consistency\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "print(\"Model in eval mode for consistency\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ultimate-posting",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "living-consumption",
+ "metadata": {},
+ "source": [
+ "#### Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "italic-resort",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-11-22 19:46:05.850 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%capture\n",
+ "\n",
+ "from newlm.utils.file_util import read_from_yaml\n",
+ "config_file = read_from_yaml('examples/configs_gcloud/run-100-percent.elmo-bert-causal.yaml')\n",
+ "\n",
+ "# lm builder (helper)\n",
+ "elmo_lm_builder = ELMOLMBuilder(\n",
+ " model_config = config_file['lm']['model']['config'],\n",
+ " tokenizer=\"./outputs/en.100-percent.elmo-bert-causal.40k\", # use pre-trained tokenizer\n",
+ " model_type=\"bert-causal-elmo\",\n",
+ " max_len=128\n",
+ ")\n",
+ "\n",
+ "# dataset-forward\n",
+ "train_path = \"./examples/data/text_forward-small.txt\"\n",
+ "ds_f = elmo_lm_builder._get_dataset(train_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "forward-litigation",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "max_steps is given, it will override any value given in num_train_epochs\n"
+ ]
+ }
+ ],
+ "source": [
+ "# trainer (helper)\n",
+ "from transformers import TrainingArguments, Trainer\n",
+ "args = TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n",
+ "\n",
+ "# dataloader-forward\n",
+ "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator, train_dataset=ds_f,)\n",
+ "dl_f = trainer.get_train_dataloader() # Data Loader-forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "subject-exposure",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 123])"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f = next(iter(dl_f))\n",
+ "batch_f['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "supposed-relief",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "international-debate",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model in eval mode for consistency\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "print(\"Model in eval mode for consistency\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "substantial-warning",
+ "metadata": {},
+ "source": [
+ "## Sanity Check"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "weekly-reverse",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# batch_f"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "optimum-italy",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def reverse_batch(batch_f):\n",
+ " # reverse input\n",
+ " batch_f_input = torch.clone(batch_f['input_ids'])\n",
+ " batch_f_rev_input = torch.cat(\n",
+ " (\n",
+ " batch_f_input[0][0:1],\n",
+ " torch.flip(batch_f_input[0][1:-1], [0]),\n",
+ " batch_f_input[0][-1:]\n",
+ " )\n",
+ " )\n",
+ " batch_f_rev_input = batch_f_rev_input.reshape(1,-1)\n",
+ "\n",
+ " # reverse labels\n",
+ " batch_f_rev_labels = torch.clone(batch_f_rev_input)\n",
+ " \n",
+ " # batch_rev\n",
+ " batch_rev = batch_f.copy()\n",
+ " batch_rev['input_ids'] = batch_f_rev_input\n",
+ " batch_rev['labels'] = batch_f_rev_labels\n",
+ " \n",
+ " return batch_rev"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "eastern-transfer",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "def pandas_check(batch_f, batch_rev):\n",
+ " tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n",
+ " tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n",
+ " return pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "flush-syracuse",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "human-reynolds",
+ "metadata": {},
+ "source": [
+ "#### Normal vs Reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "sexual-meditation",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_rev = reverse_batch(batch_f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "rotary-ontario",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 [UNK] .\n",
+ "2 is Jews\n",
+ "3 located among\n",
+ "4 in lived\n",
+ ".. ... ...\n",
+ "118 lived in\n",
+ "119 among located\n",
+ "120 Jews is\n",
+ "121 . [UNK]\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_f, batch_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "beginning-committee",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(torch.Size([1, 123]), torch.Size([1, 123]))"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f['input_ids'].shape, batch_rev['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "cultural-block",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(3.9125, grad_fn=)\n",
+ "r2l_loss tensor(3.8753, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_f) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "sublime-rider",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(8.8075, grad_fn=)\n",
+ "r2l_loss tensor(8.8464, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "european-examination",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "synthetic-volleyball",
+ "metadata": {},
+ "source": [
+ "#### Random String"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "ecological-advocate",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 123])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "cosmetic-shame",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 'shuffle' baseline: keep [CLS]/[SEP], replace the middle with random token ids\n",
+ "batch_f_input = batch_f['input_ids']\n",
+ "batch_shuffle_input = torch.cat(\n",
+ " (\n",
+ " batch_f_input[0][0:1],\n",
+ " torch.randint(\n",
+ " low=5, # 0-4 > ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']\n",
+ " high=29999,\n",
+ " size=(121,), # 121 = seq_len 123 minus [CLS] and [SEP]; adjust if the batch length changes\n",
+ " dtype=torch.long\n",
+ " ),\n",
+ " batch_f_input[0][-1:]\n",
+ " )\n",
+ ")\n",
+ "batch_shuffle_input = batch_shuffle_input.reshape(1,-1)\n",
+ "# labels\n",
+ "batch_shuffle_labels = torch.clone(batch_shuffle_input) \n",
+ "# batch_shuffle\n",
+ "batch_shuffle = batch_f.copy()\n",
+ "batch_shuffle['input_ids'] = batch_shuffle_input\n",
+ "batch_shuffle['labels'] = batch_shuffle_labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "respective-receiver",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_shuffle_rev = reverse_batch(batch_shuffle)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "religious-camping",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "advisory-subject",
+ "metadata": {},
+ "source": [
+ "##### Trial-1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "rotary-reduction",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 ##wich battery\n",
+ "2 desk Ner\n",
+ "3 Collection ought\n",
+ "4 fisher Mug\n",
+ ".. ... ...\n",
+ "118 Mug fisher\n",
+ "119 ought Collection\n",
+ "120 Ner desk\n",
+ "121 battery ##wich\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "presidential-impossible",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.4756, grad_fn=)\n",
+ "r2l_loss tensor(13.2225, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "figured-carroll",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.2165, grad_fn=)\n",
+ "r2l_loss tensor(13.3072, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fabulous-dutch",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dominant-falls",
+ "metadata": {},
+ "source": [
+ "##### Trial-2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "professional-visitor",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 ##wich battery\n",
+ "2 desk Ner\n",
+ "3 Collection ought\n",
+ "4 fisher Mug\n",
+ ".. ... ...\n",
+ "118 Mug fisher\n",
+ "119 ought Collection\n",
+ "120 Ner desk\n",
+ "121 battery ##wich\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "reflected-thursday",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.4756, grad_fn=)\n",
+ "r2l_loss tensor(13.2225, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "stupid-violin",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.2165, grad_fn=)\n",
+ "r2l_loss tensor(13.3072, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "quantitative-memory",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "marine-settle",
+ "metadata": {},
+ "source": [
+ "##### Trial-3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "informal-wright",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 ##wich battery\n",
+ "2 desk Ner\n",
+ "3 Collection ought\n",
+ "4 fisher Mug\n",
+ ".. ... ...\n",
+ "118 Mug fisher\n",
+ "119 ought Collection\n",
+ "120 Ner desk\n",
+ "121 battery ##wich\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "devoted-brush",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.4756, grad_fn=)\n",
+ "r2l_loss tensor(13.2225, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "supported-entrance",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.2165, grad_fn=)\n",
+ "r2l_loss tensor(13.3072, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "inclusive-poker",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
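
Note on the check above: `reverse_batch` keeps the [CLS]/[SEP] positions fixed and flips only the tokens in between, so on the reversed batch the left-to-right head is effectively asked to model right-to-left text, which is why its loss jumps from roughly 3.9 to roughly 8.8. A minimal, self-contained sketch of just that transformation; the helper name and the toy ids (101/102 standing in for [CLS]/[SEP]) are illustrative, not part of newlm:

import torch

def reverse_inner_tokens(input_ids: torch.Tensor) -> torch.Tensor:
    # keep the first ([CLS]) and last ([SEP]) ids fixed, flip everything in between
    ids = input_ids[0]
    flipped = torch.cat((ids[0:1], torch.flip(ids[1:-1], [0]), ids[-1:]))
    return flipped.reshape(1, -1)

toy = torch.tensor([[101, 7, 8, 9, 10, 102]])  # hypothetical ids
print(reverse_inner_tokens(toy))               # prints the ids as [101, 10, 9, 8, 7, 102]
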
diff --git a/elmo-sanity-check-v2-elmo_small_40k.ipynb b/elmo-sanity-check-v2-elmo_small_40k.ipynb
new file mode 100644
index 0000000..679bcb5
--- /dev/null
+++ b/elmo-sanity-check-v2-elmo_small_40k.ipynb
@@ -0,0 +1,1104 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "swedish-breast",
+ "metadata": {},
+ "source": [
+ "## Prepare Data + Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "former-delta",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !cat examples/data/text_forward.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "noted-jacksonville",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !ls -al ./outputs/en.1-percent.elmo-bert-causal-fixr2l"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "requested-nudist",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "atomic-orbit",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n",
+ "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n",
+ "from transformers import BertConfig"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "empirical-forestry",
+ "metadata": {},
+ "source": [
+ "#### Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "outside-consultancy",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ELMOBertLMHeadModel.from_pretrained(\n",
+ " \"./outputs/en.100-percent.elmo-small-bert-causal.40k\"\n",
+ ") # use pre-trained model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "fewer-harmony",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model in eval mode for consistency\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "print(\"Model in eval mode for consistency\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "impressive-convert",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "radical-secret",
+ "metadata": {},
+ "source": [
+ "#### Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "sexual-wisconsin",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-11-22 19:37:37.967 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%capture\n",
+ "\n",
+ "from newlm.utils.file_util import read_from_yaml\n",
+ "config_file = read_from_yaml('examples/configs_gcloud/run-100-percent.elmo-small-bert-causal.yaml')\n",
+ "\n",
+ "# lm builder (helper)\n",
+ "elmo_lm_builder = ELMOLMBuilder(\n",
+ " model_config = config_file['lm']['model']['config'],\n",
+ " tokenizer=\"./outputs/en.100-percent.elmo-small-bert-causal.40k\", # use pre-trained tokenizer\n",
+ " model_type=\"bert-causal-elmo\",\n",
+ " max_len=128\n",
+ ")\n",
+ "\n",
+ "# dataset-forward\n",
+ "train_path = \"./examples/data/text_forward-small.txt\"\n",
+ "ds_f = elmo_lm_builder._get_dataset(train_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "sophisticated-texas",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "max_steps is given, it will override any value given in num_train_epochs\n"
+ ]
+ }
+ ],
+ "source": [
+ "# trainer (helper)\n",
+ "from transformers import TrainingArguments, Trainer\n",
+ "args = TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n",
+ "\n",
+ "# dataloader-forward\n",
+ "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator, train_dataset=ds_f,)\n",
+ "dl_f = trainer.get_train_dataloader() # Data Loader-forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "attractive-insight",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 123])"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f = next(iter(dl_f))\n",
+ "batch_f['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "incorporate-integration",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "original-voluntary",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model in eval mode for consistency\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "print(\"Model in eval mode for consistency\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "empirical-accessory",
+ "metadata": {},
+ "source": [
+ "## Sanity Check"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "impaired-soviet",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# batch_f"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "activated-platinum",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def reverse_batch(batch_f):\n",
+ " # reverse input\n",
+ " batch_f_input = torch.clone(batch_f['input_ids'])\n",
+ " batch_f_rev_input = torch.cat(\n",
+ " (\n",
+ " batch_f_input[0][0:1],\n",
+ " torch.flip(batch_f_input[0][1:-1], [0]),\n",
+ " batch_f_input[0][-1:]\n",
+ " )\n",
+ " )\n",
+ " batch_f_rev_input = batch_f_rev_input.reshape(1,-1)\n",
+ "\n",
+ " # reverse labels\n",
+ " batch_f_rev_labels = torch.clone(batch_f_rev_input)\n",
+ " \n",
+ " # batch_rev\n",
+ " batch_rev = batch_f.copy()\n",
+ " batch_rev['input_ids'] = batch_f_rev_input\n",
+ " batch_rev['labels'] = batch_f_rev_labels\n",
+ " \n",
+ " return batch_rev"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "adopted-tamil",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "def pandas_check(batch_f, batch_rev):\n",
+ " tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n",
+ " tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n",
+ " return pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "opponent-grant",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "electoral-effort",
+ "metadata": {},
+ "source": [
+ "#### Normal vs Reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "elegant-occasions",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_rev = reverse_batch(batch_f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "demanding-wealth",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 [UNK] .\n",
+ "2 is Jews\n",
+ "3 located among\n",
+ "4 in lived\n",
+ ".. ... ...\n",
+ "118 lived in\n",
+ "119 among located\n",
+ "120 Jews is\n",
+ "121 . [UNK]\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_f, batch_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "stuffed-lightweight",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(torch.Size([1, 123]), torch.Size([1, 123]))"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f['input_ids'].shape, batch_rev['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "norman-pricing",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(4.0284, grad_fn=)\n",
+ "r2l_loss tensor(3.9578, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_f) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "divine-joseph",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(9.0815, grad_fn=)\n",
+ "r2l_loss tensor(9.1060, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "hairy-logan",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "expanded-portal",
+ "metadata": {},
+ "source": [
+ "#### Random String"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "understanding-knock",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 123])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "raising-deviation",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 'shuffle' baseline: keep [CLS]/[SEP], replace the middle with random token ids\n",
+ "batch_f_input = batch_f['input_ids']\n",
+ "batch_shuffle_input = torch.cat(\n",
+ " (\n",
+ " batch_f_input[0][0:1],\n",
+ " torch.randint(\n",
+ " low=5, # 0-4 > ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']\n",
+ " high=29999,\n",
+ " size=(121,), # 121 = seq_len 123 minus [CLS] and [SEP]; adjust if the batch length changes\n",
+ " dtype=torch.long\n",
+ " ),\n",
+ " batch_f_input[0][-1:]\n",
+ " )\n",
+ ")\n",
+ "batch_shuffle_input = batch_shuffle_input.reshape(1,-1)\n",
+ "# labels\n",
+ "batch_shuffle_labels = torch.clone(batch_shuffle_input) \n",
+ "# batch_shuffle\n",
+ "batch_shuffle = batch_f.copy()\n",
+ "batch_shuffle['input_ids'] = batch_shuffle_input\n",
+ "batch_shuffle['labels'] = batch_shuffle_labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "headed-aside",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_shuffle_rev = reverse_batch(batch_shuffle)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "distinct-steel",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dirty-doctor",
+ "metadata": {},
+ "source": [
+ "##### Trial-1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "virtual-church",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 ##wich battery\n",
+ "2 desk Ner\n",
+ "3 Collection ought\n",
+ "4 fisher Mug\n",
+ ".. ... ...\n",
+ "118 Mug fisher\n",
+ "119 ought Collection\n",
+ "120 Ner desk\n",
+ "121 battery ##wich\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "artificial-salem",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(14.4735, grad_fn=)\n",
+ "r2l_loss tensor(13.7806, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "blond-pitch",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(14.0861, grad_fn=)\n",
+ "r2l_loss tensor(13.5071, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "clean-favorite",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "another-penetration",
+ "metadata": {},
+ "source": [
+ "##### Trial-2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "brilliant-antarctica",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 ##wich battery\n",
+ "2 desk Ner\n",
+ "3 Collection ought\n",
+ "4 fisher Mug\n",
+ ".. ... ...\n",
+ "118 Mug fisher\n",
+ "119 ought Collection\n",
+ "120 Ner desk\n",
+ "121 battery ##wich\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "animal-break",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(14.4735, grad_fn=)\n",
+ "r2l_loss tensor(13.7806, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "unlikely-table",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(14.0861, grad_fn=)\n",
+ "r2l_loss tensor(13.5071, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "boolean-supplier",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "mexican-canadian",
+ "metadata": {},
+ "source": [
+ "##### Trial-3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "assured-bahrain",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 ##wich battery\n",
+ "2 desk Ner\n",
+ "3 Collection ought\n",
+ "4 fisher Mug\n",
+ ".. ... ...\n",
+ "118 Mug fisher\n",
+ "119 ought Collection\n",
+ "120 Ner desk\n",
+ "121 battery ##wich\n",
+ "122 [SEP] [SEP]\n",
+ "\n",
+ "[123 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "promotional-panel",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(14.4735, grad_fn=)\n",
+ "r2l_loss tensor(13.7806, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "moving-karma",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(14.0861, grad_fn=)\n",
+ "r2l_loss tensor(13.5071, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "clean-gibson",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
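
The random-token baseline in the notebooks above hard-codes size=(121,) to fit a 123-token batch. The same construction can be derived from the batch shape instead; a sketch under the assumptions stated in the cell comments (ids 0-4 reserved for special tokens, a roughly 30k vocabulary), with an illustrative helper name that is not part of newlm:

import torch

def random_inner_tokens(input_ids: torch.Tensor, vocab_size: int = 30000) -> torch.Tensor:
    # keep [CLS]/[SEP] from the batch, replace everything in between with random ids
    seq_len = input_ids.shape[1]
    middle = torch.randint(low=5, high=vocab_size, size=(seq_len - 2,), dtype=torch.long)
    ids = input_ids[0]
    return torch.cat((ids[0:1], middle, ids[-1:])).reshape(1, -1)

toy = torch.tensor([[2, 11, 12, 13, 3]])  # hypothetical ids; 2/3 stand in for [CLS]/[SEP]
print(random_inner_tokens(toy).shape)     # torch.Size([1, 5])
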
diff --git a/elmo-sanity-check-v2.ipynb b/elmo-sanity-check-v2.ipynb
new file mode 100644
index 0000000..9004f6f
--- /dev/null
+++ b/elmo-sanity-check-v2.ipynb
@@ -0,0 +1,1104 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "working-fisher",
+ "metadata": {},
+ "source": [
+ "## Prepare Data + Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "quiet-macro",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !cat examples/data/text_forward.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "resistant-denver",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !ls -al ./outputs/en.1-percent.elmo-bert-causal-fixr2l"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "egyptian-association",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "robust-strategy",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n",
+ "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n",
+ "from transformers import BertConfig"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "beneficial-franchise",
+ "metadata": {},
+ "source": [
+ "#### Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "supported-strategy",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ELMOBertLMHeadModel.from_pretrained(\n",
+ " \"./outputs/en.1-percent.elmo-bert-causal-fixr2l\"\n",
+ ") # use pre-trained model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "emotional-moisture",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model in eval mode for consistency\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "print(\"Model in eval mode for consistency\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "following-execution",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "sized-enemy",
+ "metadata": {},
+ "source": [
+ "#### Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "previous-crack",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-11-13 00:29:52.826 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%capture\n",
+ "\n",
+ "from newlm.utils.file_util import read_from_yaml\n",
+ "config_file = read_from_yaml('examples/configs/run.1-percent-elmo-bert-causal.yaml')\n",
+ "\n",
+ "# lm builder (helper)\n",
+ "elmo_lm_builder = ELMOLMBuilder(\n",
+ " model_config = config_file['lm']['model']['config'],\n",
+ " tokenizer=\"./outputs/en.1-percent.elmo-bert-causal-fixr2l\", # use pre-trained tokenizer\n",
+ " model_type=\"bert-causal-elmo\",\n",
+ " max_len=128\n",
+ ")\n",
+ "\n",
+ "# dataset-forward\n",
+ "train_path = \"./examples/data/text_forward-small.txt\"\n",
+ "ds_f = elmo_lm_builder._get_dataset(train_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "mighty-major",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "max_steps is given, it will override any value given in num_train_epochs\n"
+ ]
+ }
+ ],
+ "source": [
+ "# trainer (helper)\n",
+ "from transformers import TrainingArguments, Trainer\n",
+ "args = TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n",
+ "\n",
+ "# dataloader-forward\n",
+ "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator, train_dataset=ds_f,)\n",
+ "dl_f = trainer.get_train_dataloader() # Data Loader-forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "attempted-tomato",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 127])"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f = next(iter(dl_f))\n",
+ "batch_f['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fallen-statement",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "paperback-luther",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model in eval mode for consistency\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "print(\"Model in eval mode for consistency\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fluid-product",
+ "metadata": {},
+ "source": [
+ "## Sanity Check"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "dimensional-sight",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# batch_f"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "environmental-experience",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def reverse_batch(batch_f):\n",
+ " # reverse input\n",
+ " batch_f_input = torch.clone(batch_f['input_ids'])\n",
+ " batch_f_rev_input = torch.cat(\n",
+ " (\n",
+ " batch_f_input[0][0:1],\n",
+ " torch.flip(batch_f_input[0][1:-1], [0]),\n",
+ " batch_f_input[0][-1:]\n",
+ " )\n",
+ " )\n",
+ " batch_f_rev_input = batch_f_rev_input.reshape(1,-1)\n",
+ "\n",
+ " # reverse labels\n",
+ " batch_f_rev_labels = torch.clone(batch_f_rev_input)\n",
+ " \n",
+ " # batch_rev\n",
+ " batch_rev = batch_f.copy()\n",
+ " batch_rev['input_ids'] = batch_f_rev_input\n",
+ " batch_rev['labels'] = batch_f_rev_labels\n",
+ " \n",
+ " return batch_rev"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "informed-termination",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "def pandas_check(batch_f, batch_rev):\n",
+ " tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n",
+ " tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n",
+ " return pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "complex-music",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "tight-bennett",
+ "metadata": {},
+ "source": [
+ "#### Normal vs Reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "respective-beauty",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_rev = reverse_batch(batch_f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "direct-warren",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 R .\n",
+ "2 ##ø Jews\n",
+ "3 ##d among\n",
+ "4 ##berg lived\n",
+ ".. ... ...\n",
+ "122 lived ##berg\n",
+ "123 among ##d\n",
+ "124 Jews ##ø\n",
+ "125 . R\n",
+ "126 [SEP] [SEP]\n",
+ "\n",
+ "[127 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_f, batch_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "fourth-salmon",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(torch.Size([1, 127]), torch.Size([1, 127]))"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f['input_ids'].shape, batch_rev['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "violent-shell",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(4.4801, grad_fn=)\n",
+ "r2l_loss tensor(4.3589, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_f) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "related-mambo",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(9.1742, grad_fn=)\n",
+ "r2l_loss tensor(9.0804, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "retired-bahamas",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "broad-agriculture",
+ "metadata": {},
+ "source": [
+ "#### Random String"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "remarkable-sailing",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 127])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "fifth-booth",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 'shuffle' baseline: keep [CLS]/[SEP], replace the middle with random token ids\n",
+ "batch_f_input = batch_f['input_ids']\n",
+ "batch_shuffle_input = torch.cat(\n",
+ " (\n",
+ " batch_f_input[0][0:1],\n",
+ " torch.randint(\n",
+ " low=5, # 0-4 > ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']\n",
+ " high=29999,\n",
+ " size=(125,), # 125 = seq_len 127 minus [CLS] and [SEP]; adjust if the batch length changes\n",
+ " dtype=torch.long\n",
+ " ),\n",
+ " batch_f_input[0][-1:]\n",
+ " )\n",
+ ")\n",
+ "batch_shuffle_input = batch_shuffle_input.reshape(1,-1)\n",
+ "# labels\n",
+ "batch_shuffle_labels = torch.clone(batch_shuffle_input) \n",
+ "# batch_shuffle\n",
+ "batch_shuffle = batch_f.copy()\n",
+ "batch_shuffle['input_ids'] = batch_shuffle_input\n",
+ "batch_shuffle['labels'] = batch_shuffle_labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "white-juice",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_shuffle_rev = reverse_batch(batch_shuffle)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "addressed-collective",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "abroad-johnson",
+ "metadata": {},
+ "source": [
+ "##### Trial-1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "inclusive-station",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 Lex ##leg\n",
+ "2 neither squinting\n",
+ "3 fever ы\n",
+ "4 Leeds Russell\n",
+ ".. ... ...\n",
+ "122 Russell Leeds\n",
+ "123 ы fever\n",
+ "124 squinting neither\n",
+ "125 ##leg Lex\n",
+ "126 [SEP] [SEP]\n",
+ "\n",
+ "[127 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "hungarian-soundtrack",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.5902, grad_fn=)\n",
+ "r2l_loss tensor(13.4444, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "employed-yukon",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.7210, grad_fn=)\n",
+ "r2l_loss tensor(13.3258, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "voluntary-necessity",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "judicial-republic",
+ "metadata": {},
+ "source": [
+ "##### Trial-2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "reported-sweden",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 Byrne Clean\n",
+ "2 vanilla Wear\n",
+ "3 departed flickered\n",
+ "4 ##ische ##ghter\n",
+ ".. ... ...\n",
+ "122 ##ghter ##ische\n",
+ "123 flickered departed\n",
+ "124 Wear vanilla\n",
+ "125 Clean Byrne\n",
+ "126 [SEP] [SEP]\n",
+ "\n",
+ "[127 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "equipped-ordinary",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.4740, grad_fn=)\n",
+ "r2l_loss tensor(13.1427, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "celtic-spider",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.3491, grad_fn=)\n",
+ "r2l_loss tensor(13.3746, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "rough-kelly",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "invisible-white",
+ "metadata": {},
+ "source": [
+ "##### Trial-3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "chubby-progressive",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 Severance Hild\n",
+ "2 endorse nor\n",
+ "3 reviewed ##frey\n",
+ "4 Alone Fred\n",
+ ".. ... ...\n",
+ "122 Fred Alone\n",
+ "123 ##frey reviewed\n",
+ "124 nor endorse\n",
+ "125 Hild Severance\n",
+ "126 [SEP] [SEP]\n",
+ "\n",
+ "[127 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pandas_check(batch_shuffle, batch_shuffle_rev)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "partial-devil",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.9457, grad_fn=)\n",
+ "r2l_loss tensor(13.5547, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle) # forward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "structured-farmer",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(13.5305, grad_fn=)\n",
+ "r2l_loss tensor(13.4759, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_shuffle_rev) # reverse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "funky-space",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
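
The last file below, elmo-sanity-check.ipynb, prepares a word-reversed copy of the training text one cell at a time (strip, split, reverse each line, reverse the line order, write). A condensed sketch of that preprocessing, assuming the same file layout as in the notebook; the function name is illustrative:

def write_backward_copy(src_path: str, dst_path: str) -> None:
    # reverse the word order within each line, then reverse the order of the lines
    with open(src_path) as fr:
        lines = [line.strip().split() for line in fr]
    reversed_lines = [" ".join(words[::-1]) for words in lines][::-1]
    with open(dst_path, "w") as fw:
        fw.write("\n".join(reversed_lines))

# e.g. write_backward_copy("./examples/data/text_forward-small.txt",
#                          "./examples/data/text_backward-small.txt")
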
diff --git a/elmo-sanity-check.ipynb b/elmo-sanity-check.ipynb
new file mode 100644
index 0000000..9dc106d
--- /dev/null
+++ b/elmo-sanity-check.ipynb
@@ -0,0 +1,669 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "laden-tsunami",
+ "metadata": {},
+ "source": [
+ "## Prepare data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "lovely-flood",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"./examples/data/text_forward-small.txt\", \"r+\") as fr:\n",
+ " lines = fr.readlines()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "false-bernard",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lines = [line.strip() for line in lines]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "funded-listing",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lines = [line.split() for line in lines]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "medical-visiting",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lines = [line[::-1] for line in lines]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "eastern-publicity",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lines = [\" \".join(line) for line in lines]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "alien-nightmare",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lines = lines[::-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "acute-black",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lines = \"\\n\".join(lines)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "modified-links",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"./examples/data/text_backward-small.txt\", \"w+\") as fw:\n",
+ " fw.write(lines)"
+ ]
+ },
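+ {
+ "cell_type": "markdown",
+ "id": "backward-prep-sketch",
+ "metadata": {},
+ "source": [
+ "The cells above build `text_backward-small.txt` by reversing the word order within each line and then the order of the lines. A minimal consolidated sketch of the same transformation (a hypothetical one-cell version, assuming the forward file exists at the path used above):\n",
+ "\n",
+ "```python\n",
+ "# Hypothetical single-cell rewrite of the preparation steps above.\n",
+ "with open(\"./examples/data/text_forward-small.txt\") as fr:\n",
+ "    fwd = [line.strip() for line in fr]\n",
+ "\n",
+ "# Reverse the words inside each line, then reverse the order of the lines.\n",
+ "bwd = [\" \".join(line.split()[::-1]) for line in fwd][::-1]\n",
+ "\n",
+ "with open(\"./examples/data/text_backward-small.txt\", \"w\") as fw:\n",
+ "    fw.write(\"\\n\".join(bwd))\n",
+ "```"
+ ]
+ },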
+ {
+ "cell_type": "markdown",
+ "id": "absolute-world",
+ "metadata": {},
+ "source": [
+ "## Sanity Check"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "hollow-newsletter",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n",
+ "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n",
+ "from transformers import BertConfig"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "chronic-crack",
+ "metadata": {},
+ "source": [
+ "Model from scratch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "identified-ivory",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from newlm.utils.file_util import read_from_yaml\n",
+ "config_file = read_from_yaml('examples/configs/run.1-percent-bert-causal.yaml')\n",
+ "\n",
+ "elmo_lm_builder = ELMOLMBuilder(\n",
+ " model_config = config_file['lm']['model']['config'], # no pretrained model\n",
+ " tokenizer=\"./outputs/en.1-percent.elmo-bert-causal\", # use pre-trained tokenizer\n",
+ " model_type=\"bert-causal-elmo\",\n",
+ " max_len=128\n",
+ ")\n",
+ "\n",
+ "# model\n",
+ "config = BertConfig(**elmo_lm_builder.model_config)\n",
+ "model = ELMOBertLMHeadModel(config=config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "boring-workplace",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model in eval mode for consistency\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "print(\"Model in eval mode for consistency\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "deluxe-expert",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-11-12 11:07:27.881 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n",
+ "2021-11-12 11:07:29.324 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%capture\n",
+ "\n",
+ "# dataset-forward\n",
+ "train_path = \"./examples/data/text_forward-small.txt\"\n",
+ "ds_f = elmo_lm_builder._get_dataset(train_path)\n",
+ "\n",
+ "# dataset-backward\n",
+ "train_path = \"./examples/data/text_backward-small.txt\"\n",
+ "ds_b = elmo_lm_builder._get_dataset(train_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "immune-quick",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "max_steps is given, it will override any value given in num_train_epochs\n",
+ "max_steps is given, it will override any value given in num_train_epochs\n"
+ ]
+ }
+ ],
+ "source": [
+ "# trainer (for helper)\n",
+ "from transformers import TrainingArguments, Trainer\n",
+ "args = TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n",
+ "\n",
+ "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator,\n",
+ " train_dataset=ds_f,\n",
+ ")\n",
+ "dl_f = trainer.get_train_dataloader() # Data Loader-forward\n",
+ "\n",
+ "trainer = Trainer(model=model, args=args,data_collator=elmo_lm_builder.data_collator,\n",
+ " train_dataset=ds_b,\n",
+ ")\n",
+ "dl_b = trainer.get_train_dataloader() # Data Loader-backward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "enabling-business",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "preceding-evening",
+ "metadata": {},
+ "source": [
+ "### Compare text_forward and text_backward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "promising-panama",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(torch.Size([1, 127]), torch.Size([1, 127]))"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_f = next(iter(dl_f))\n",
+ "batch_b = next(iter(dl_b))\n",
+ "\n",
+ "batch_f['input_ids'].shape, batch_b['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "italic-treaty",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " forward | \n",
+ " backward | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " [CLS] | \n",
+ " [CLS] | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " R | \n",
+ " . | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ##ø | \n",
+ " Jews | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " ##d | \n",
+ " among | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " ##berg | \n",
+ " lived | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 122 | \n",
+ " lived | \n",
+ " R | \n",
+ "
\n",
+ " \n",
+ " 123 | \n",
+ " among | \n",
+ " ##ø | \n",
+ "
\n",
+ " \n",
+ " 124 | \n",
+ " Jews | \n",
+ " ##d | \n",
+ "
\n",
+ " \n",
+ " 125 | \n",
+ " . | \n",
+ " ##berg | \n",
+ "
\n",
+ " \n",
+ " 126 | \n",
+ " [SEP] | \n",
+ " [SEP] | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
127 rows × 2 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " forward backward\n",
+ "0 [CLS] [CLS]\n",
+ "1 R .\n",
+ "2 ##ø Jews\n",
+ "3 ##d among\n",
+ "4 ##berg lived\n",
+ ".. ... ...\n",
+ "122 lived R\n",
+ "123 among ##ø\n",
+ "124 Jews ##d\n",
+ "125 . ##berg\n",
+ "126 [SEP] [SEP]\n",
+ "\n",
+ "[127 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n",
+ "tokens_b = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_b['input_ids'][0])\n",
+ "\n",
+ "import pandas as pd\n",
+ "pd.DataFrame({\"forward\": tokens_f, \"backward\": tokens_b})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "attached-nursing",
+ "metadata": {},
+ "source": [
+ "Here we can see that the data is not completely flip when the tokenizer couldn't parse a single word into a single id"
+ ]
+ },
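+ {
+ "cell_type": "markdown",
+ "id": "subword-flip-sketch",
+ "metadata": {},
+ "source": [
+ "A small illustration of this effect (a hypothetical sketch with a made-up line, not output from this run): the backward file reverses whole words, but each word's subword pieces keep their forward order once tokenized:\n",
+ "\n",
+ "```python\n",
+ "forward_line = \"Rødberg lived among Jews .\"            # illustrative line only\n",
+ "backward_line = \" \".join(forward_line.split()[::-1])   # \". Jews among lived Rødberg\"\n",
+ "\n",
+ "# A WordPiece tokenizer still emits the pieces R, ##ø, ##d, ##berg in that order\n",
+ "# (now at the end of the sequence), so the backward token sequence is not an\n",
+ "# exact mirror of the forward one.\n",
+ "```"
+ ]
+ },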
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "graphic-paintball",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(10.5425, grad_fn=)\n",
+ "r2l_loss tensor(10.4237, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "frequent-scheduling",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(10.5305, grad_fn=)\n",
+ "r2l_loss tensor(10.4198, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_b)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "amber-auckland",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "unique-jungle",
+ "metadata": {},
+ "source": [
+ "### From batch_forward compare Normal vs Rev"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "right-coalition",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "batch_f_input = torch.clone(batch_f['input_ids'])\n",
+ "batch_f_rev_input = torch.cat(\n",
+ " (\n",
+ " batch_f_input[0][0:1],\n",
+ " torch.flip(batch_f_input[0][1:-1], [0]),\n",
+ " batch_f_input[0][-1:]\n",
+ " )\n",
+ ")\n",
+ "batch_f_rev_input = batch_f_rev_input.reshape(1,-1)\n",
+ "\n",
+ "batch_f_labels = torch.clone(batch_f['labels'])\n",
+ "batch_f_rev_labels = torch.cat(\n",
+ " (\n",
+ " batch_f_labels[0][0:1],\n",
+ " torch.flip(batch_f_labels[0][1:-1], [0]),\n",
+ " batch_f_labels[0][-1:]\n",
+ " )\n",
+ ")\n",
+ "batch_f_rev_labels = batch_f_rev_labels.reshape(1,-1)\n",
+ "\n",
+ "batch_rev = batch_f.copy()\n",
+ "batch_rev['input_ids'] = batch_f_rev_input\n",
+ "batch_rev['labels'] = batch_f_rev_labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "terminal-withdrawal",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " forward | \n",
+ " reverse | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " [CLS] | \n",
+ " [CLS] | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " R | \n",
+ " . | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ##ø | \n",
+ " Jews | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " ##d | \n",
+ " among | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " ##berg | \n",
+ " lived | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 122 | \n",
+ " lived | \n",
+ " ##berg | \n",
+ "
\n",
+ " \n",
+ " 123 | \n",
+ " among | \n",
+ " ##d | \n",
+ "
\n",
+ " \n",
+ " 124 | \n",
+ " Jews | \n",
+ " ##ø | \n",
+ "
\n",
+ " \n",
+ " 125 | \n",
+ " . | \n",
+ " R | \n",
+ "
\n",
+ " \n",
+ " 126 | \n",
+ " [SEP] | \n",
+ " [SEP] | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
127 rows × 2 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " forward reverse\n",
+ "0 [CLS] [CLS]\n",
+ "1 R .\n",
+ "2 ##ø Jews\n",
+ "3 ##d among\n",
+ "4 ##berg lived\n",
+ ".. ... ...\n",
+ "122 lived ##berg\n",
+ "123 among ##d\n",
+ "124 Jews ##ø\n",
+ "125 . R\n",
+ "126 [SEP] [SEP]\n",
+ "\n",
+ "[127 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n",
+ "tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n",
+ "\n",
+ "import pandas as pd\n",
+ "pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "continued-recommendation",
+ "metadata": {},
+ "source": [
+ "with the exception of [CLS] and [SEP], the data are completely flip"
+ ]
+ },
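+ {
+ "cell_type": "markdown",
+ "id": "flip-helper-sketch",
+ "metadata": {},
+ "source": [
+ "The manual construction of `batch_rev` above can be folded into a small helper (a hypothetical sketch, not part of the repository; it assumes a batch size of 1 as used in this notebook):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def flip_between_special_tokens(t: torch.Tensor) -> torch.Tensor:\n",
+ "    \"\"\"Reverse a (1, seq_len) tensor while keeping the first and last positions fixed.\"\"\"\n",
+ "    row = t[0]\n",
+ "    flipped = torch.cat((row[:1], torch.flip(row[1:-1], [0]), row[-1:]))\n",
+ "    return flipped.reshape(1, -1)\n",
+ "\n",
+ "batch_rev = batch_f.copy()\n",
+ "batch_rev[\"input_ids\"] = flip_between_special_tokens(batch_f[\"input_ids\"])\n",
+ "batch_rev[\"labels\"] = flip_between_special_tokens(batch_f[\"labels\"])\n",
+ "```"
+ ]
+ },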
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "narrow-pipeline",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(10.5425, grad_fn=)\n",
+ "r2l_loss tensor(10.4237, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "desirable-peripheral",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "l2r_loss tensor(10.5323, grad_fn=)\n",
+ "r2l_loss tensor(10.4094, grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "res = model(**batch_rev)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/newlm/lm/elmo/lm_builder.py b/newlm/lm/elmo/lm_builder.py
index 590a05f..28bf661 100644
--- a/newlm/lm/elmo/lm_builder.py
+++ b/newlm/lm/elmo/lm_builder.py
@@ -135,7 +135,7 @@ def preprocess_function(examples):
encoded_dataset = dataset.map(preprocess_function, batched=True)
return encoded_dataset["train"]
- def __get_dataset(self, train_path):
+ def _get_dataset(self, train_path):
dataset = self.__get_dataset_via_ds(train_path)["input_ids"]
print(len(dataset))
diff --git a/newlm/lm/elmo/modeling_elmo/elmo_head.py b/newlm/lm/elmo/modeling_elmo/elmo_head.py
index 4683fc7..fc3e0cd 100644
--- a/newlm/lm/elmo/modeling_elmo/elmo_head.py
+++ b/newlm/lm/elmo/modeling_elmo/elmo_head.py
@@ -190,6 +190,8 @@ def forward(
last_hidden_states, flip_labels, r2l=True
)
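+ # Sanity-check logging: print the per-direction LM losses before they are summed.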
+ print("l2r_loss", l2r_loss)
+ print("r2l_loss", r2l_loss)
total_loss = l2r_loss + r2l_loss if labels is not None else None
return ElmoGPTCausalLMOutput(