diff --git a/elmo-sanity-check-v2-elmo_40k.ipynb b/elmo-sanity-check-v2-elmo_40k.ipynb new file mode 100644 index 0000000..d2c4073 --- /dev/null +++ b/elmo-sanity-check-v2-elmo_40k.ipynb @@ -0,0 +1,1104 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "refined-creek", + "metadata": {}, + "source": [ + "## Prepare Data + Model" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "generous-timothy", + "metadata": {}, + "outputs": [], + "source": [ + "# !cat examples/data/text_forward.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "valid-lingerie", + "metadata": {}, + "outputs": [], + "source": [ + "# !ls -al ./outputs/en.1-percent.elmo-bert-causal-fixr2l" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "optional-motivation", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "tired-commodity", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n", + "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n", + "from transformers import BertConfig" + ] + }, + { + "cell_type": "markdown", + "id": "outdoor-recognition", + "metadata": {}, + "source": [ + "#### Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fifty-projector", + "metadata": {}, + "outputs": [], + "source": [ + "model = ELMOBertLMHeadModel.from_pretrained(\n", + " \"./outputs/en.100-percent.elmo-bert-causal.40k\"\n", + ") # use pre-trained model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "former-neighbor", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model in eval mode for consistency\n" + ] + } + ], + "source": [ + "model.eval()\n", + "print(\"Model in eval mode for consistency\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ultimate-posting", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "living-consumption", + "metadata": {}, + "source": [ + "#### Data" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "italic-resort", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-11-22 19:46:05.850 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n" + ] + } + ], + "source": [ + "%%capture\n", + "\n", + "from newlm.utils.file_util import read_from_yaml\n", + "config_file = read_from_yaml('examples/configs_gcloud/run-100-percent.elmo-bert-causal.yaml')\n", + "\n", + "# lm builder (helper)\n", + "elmo_lm_builder = ELMOLMBuilder(\n", + " model_config = config_file['lm']['model']['config'],\n", + " tokenizer=\"./outputs/en.100-percent.elmo-bert-causal.40k\", # use pre-trained tokenizer\n", + " model_type=\"bert-causal-elmo\",\n", + " max_len=128\n", + ")\n", + "\n", + "# dataset-forward\n", + "train_path = \"./examples/data/text_forward-small.txt\"\n", + "ds_f = elmo_lm_builder._get_dataset(train_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "forward-litigation", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + } + ], + "source": [ + "# trainer (helper)\n", + "from transformers import TrainingArguments, Trainer\n", + "args = 
TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n", + "\n", + "# dataloader-forward\n", + "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator, train_dataset=ds_f,)\n", + "dl_f = trainer.get_train_dataloader() # Data Loader-forward" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "subject-exposure", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 123])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f = next(iter(dl_f))\n", + "batch_f['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "supposed-relief", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "international-debate", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model in eval mode for consistency\n" + ] + } + ], + "source": [ + "model.eval()\n", + "print(\"Model in eval mode for consistency\")" + ] + }, + { + "cell_type": "markdown", + "id": "substantial-warning", + "metadata": {}, + "source": [ + "## Sanity Check" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "weekly-reverse", + "metadata": {}, + "outputs": [], + "source": [ + "# batch_f" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "optimum-italy", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def reverse_batch(batch_f):\n", + " # reverse input\n", + " batch_f_input = torch.clone(batch_f['input_ids'])\n", + " batch_f_rev_input = torch.cat(\n", + " (\n", + " batch_f_input[0][0:1],\n", + " torch.flip(batch_f_input[0][1:-1], [0]),\n", + " batch_f_input[0][-1:]\n", + " )\n", + " )\n", + " batch_f_rev_input = batch_f_rev_input.reshape(1,-1)\n", + "\n", + " # reverse labels\n", + " batch_f_rev_labels = torch.clone(batch_f_rev_input)\n", + " \n", + " # batch_rev\n", + " batch_rev = batch_f.copy()\n", + " batch_rev['input_ids'] = batch_f_rev_input\n", + " batch_rev['labels'] = batch_f_rev_labels\n", + " \n", + " return batch_rev" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "eastern-transfer", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "def pandas_check(batch_f, batch_rev):\n", + " tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n", + " tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n", + " return pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "flush-syracuse", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "human-reynolds", + "metadata": {}, + "source": [ + "#### Normal vs Reverse" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "sexual-meditation", + "metadata": {}, + "outputs": [], + "source": [ + "batch_rev = reverse_batch(batch_f)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "rotary-ontario", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1[UNK].
2isJews
3locatedamong
4inlived
.........
118livedin
119amonglocated
120Jewsis
121.[UNK]
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 [UNK] .\n", + "2 is Jews\n", + "3 located among\n", + "4 in lived\n", + ".. ... ...\n", + "118 lived in\n", + "119 among located\n", + "120 Jews is\n", + "121 . [UNK]\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_f, batch_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "beginning-committee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 123]), torch.Size([1, 123]))" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f['input_ids'].shape, batch_rev['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "cultural-block", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(3.9125, grad_fn=)\n", + "r2l_loss tensor(3.8753, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_f) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "sublime-rider", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(8.8075, grad_fn=)\n", + "r2l_loss tensor(8.8464, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "european-examination", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "synthetic-volleyball", + "metadata": {}, + "source": [ + "#### Random String" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ecological-advocate", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 123])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "cosmetic-shame", + "metadata": {}, + "outputs": [], + "source": [ + "# shuffle data\n", + "batch_f_input = batch_f['input_ids']\n", + "batch_shuffle_input = torch.cat(\n", + " (\n", + " batch_f_input[0][0:1],\n", + " torch.randint(\n", + " low=5, # 0-4 > ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']\n", + " high=29999,\n", + " size=(121,), # modified based on seqlen!\n", + " dtype=torch.long\n", + " ),\n", + " batch_f_input[0][-1:]\n", + " )\n", + ")\n", + "batch_shuffle_input = batch_shuffle_input.reshape(1,-1)\n", + "# labels\n", + "batch_shuffle_labels = torch.clone(batch_shuffle_input) \n", + "# batch_shuffle\n", + "batch_shuffle = batch_f.copy()\n", + "batch_shuffle['input_ids'] = batch_shuffle_input\n", + "batch_shuffle['labels'] = batch_shuffle_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "respective-receiver", + "metadata": {}, + "outputs": [], + "source": [ + "batch_shuffle_rev = reverse_batch(batch_shuffle)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "religious-camping", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "advisory-subject", + "metadata": {}, + "source": [ + "##### Trial-1" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "rotary-reduction", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1##wichbattery
2deskNer
3Collectionought
4fisherMug
.........
118Mugfisher
119oughtCollection
120Nerdesk
121battery##wich
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 ##wich battery\n", + "2 desk Ner\n", + "3 Collection ought\n", + "4 fisher Mug\n", + ".. ... ...\n", + "118 Mug fisher\n", + "119 ought Collection\n", + "120 Ner desk\n", + "121 battery ##wich\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "presidential-impossible", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.4756, grad_fn=)\n", + "r2l_loss tensor(13.2225, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "figured-carroll", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.2165, grad_fn=)\n", + "r2l_loss tensor(13.3072, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fabulous-dutch", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "dominant-falls", + "metadata": {}, + "source": [ + "##### Trial-2" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "professional-visitor", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1##wichbattery
2deskNer
3Collectionought
4fisherMug
.........
118Mugfisher
119oughtCollection
120Nerdesk
121battery##wich
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 ##wich battery\n", + "2 desk Ner\n", + "3 Collection ought\n", + "4 fisher Mug\n", + ".. ... ...\n", + "118 Mug fisher\n", + "119 ought Collection\n", + "120 Ner desk\n", + "121 battery ##wich\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "reflected-thursday", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.4756, grad_fn=)\n", + "r2l_loss tensor(13.2225, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "stupid-violin", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.2165, grad_fn=)\n", + "r2l_loss tensor(13.3072, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "quantitative-memory", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "marine-settle", + "metadata": {}, + "source": [ + "##### Trial-3" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "informal-wright", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1##wichbattery
2deskNer
3Collectionought
4fisherMug
.........
118Mugfisher
119oughtCollection
120Nerdesk
121battery##wich
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 ##wich battery\n", + "2 desk Ner\n", + "3 Collection ought\n", + "4 fisher Mug\n", + ".. ... ...\n", + "118 Mug fisher\n", + "119 ought Collection\n", + "120 Ner desk\n", + "121 battery ##wich\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "devoted-brush", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.4756, grad_fn=)\n", + "r2l_loss tensor(13.2225, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "supported-entrance", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.2165, grad_fn=)\n", + "r2l_loss tensor(13.3072, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "inclusive-poker", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/elmo-sanity-check-v2-elmo_small_40k.ipynb b/elmo-sanity-check-v2-elmo_small_40k.ipynb new file mode 100644 index 0000000..679bcb5 --- /dev/null +++ b/elmo-sanity-check-v2-elmo_small_40k.ipynb @@ -0,0 +1,1104 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "swedish-breast", + "metadata": {}, + "source": [ + "## Prepare Data + Model" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "former-delta", + "metadata": {}, + "outputs": [], + "source": [ + "# !cat examples/data/text_forward.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "noted-jacksonville", + "metadata": {}, + "outputs": [], + "source": [ + "# !ls -al ./outputs/en.1-percent.elmo-bert-causal-fixr2l" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "requested-nudist", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "atomic-orbit", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n", + "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n", + "from transformers import BertConfig" + ] + }, + { + "cell_type": "markdown", + "id": "empirical-forestry", + "metadata": {}, + "source": [ + "#### Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "outside-consultancy", + "metadata": {}, + "outputs": [], + "source": [ + "model = ELMOBertLMHeadModel.from_pretrained(\n", + " \"./outputs/en.100-percent.elmo-small-bert-causal.40k\"\n", + ") # use pre-trained model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fewer-harmony", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model in eval mode for consistency\n" 
+ ] + } + ], + "source": [ + "model.eval()\n", + "print(\"Model in eval mode for consistency\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "impressive-convert", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "radical-secret", + "metadata": {}, + "source": [ + "#### Data" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "sexual-wisconsin", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-11-22 19:37:37.967 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n" + ] + } + ], + "source": [ + "%%capture\n", + "\n", + "from newlm.utils.file_util import read_from_yaml\n", + "config_file = read_from_yaml('examples/configs_gcloud/run-100-percent.elmo-small-bert-causal.yaml')\n", + "\n", + "# lm builder (helper)\n", + "elmo_lm_builder = ELMOLMBuilder(\n", + " model_config = config_file['lm']['model']['config'],\n", + " tokenizer=\"./outputs/en.100-percent.elmo-small-bert-causal.40k\", # use pre-trained tokenizer\n", + " model_type=\"bert-causal-elmo\",\n", + " max_len=128\n", + ")\n", + "\n", + "# dataset-forward\n", + "train_path = \"./examples/data/text_forward-small.txt\"\n", + "ds_f = elmo_lm_builder._get_dataset(train_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "sophisticated-texas", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + } + ], + "source": [ + "# trainer (helper)\n", + "from transformers import TrainingArguments, Trainer\n", + "args = TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n", + "\n", + "# dataloader-forward\n", + "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator, train_dataset=ds_f,)\n", + "dl_f = trainer.get_train_dataloader() # Data Loader-forward" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "attractive-insight", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 123])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f = next(iter(dl_f))\n", + "batch_f['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "incorporate-integration", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "original-voluntary", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model in eval mode for consistency\n" + ] + } + ], + "source": [ + "model.eval()\n", + "print(\"Model in eval mode for consistency\")" + ] + }, + { + "cell_type": "markdown", + "id": "empirical-accessory", + "metadata": {}, + "source": [ + "## Sanity Check" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "impaired-soviet", + "metadata": {}, + "outputs": [], + "source": [ + "# batch_f" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "activated-platinum", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def reverse_batch(batch_f):\n", + " # reverse input\n", + " batch_f_input = torch.clone(batch_f['input_ids'])\n", + " batch_f_rev_input = torch.cat(\n", + " (\n", + " batch_f_input[0][0:1],\n", + " torch.flip(batch_f_input[0][1:-1], [0]),\n", + " batch_f_input[0][-1:]\n", 
+ " )\n", + " )\n", + " batch_f_rev_input = batch_f_rev_input.reshape(1,-1)\n", + "\n", + " # reverse labels\n", + " batch_f_rev_labels = torch.clone(batch_f_rev_input)\n", + " \n", + " # batch_rev\n", + " batch_rev = batch_f.copy()\n", + " batch_rev['input_ids'] = batch_f_rev_input\n", + " batch_rev['labels'] = batch_f_rev_labels\n", + " \n", + " return batch_rev" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "adopted-tamil", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "def pandas_check(batch_f, batch_rev):\n", + " tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n", + " tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n", + " return pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "opponent-grant", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "electoral-effort", + "metadata": {}, + "source": [ + "#### Normal vs Reverse" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "elegant-occasions", + "metadata": {}, + "outputs": [], + "source": [ + "batch_rev = reverse_batch(batch_f)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "demanding-wealth", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1[UNK].
2isJews
3locatedamong
4inlived
.........
118livedin
119amonglocated
120Jewsis
121.[UNK]
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 [UNK] .\n", + "2 is Jews\n", + "3 located among\n", + "4 in lived\n", + ".. ... ...\n", + "118 lived in\n", + "119 among located\n", + "120 Jews is\n", + "121 . [UNK]\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_f, batch_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "stuffed-lightweight", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 123]), torch.Size([1, 123]))" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f['input_ids'].shape, batch_rev['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "norman-pricing", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(4.0284, grad_fn=)\n", + "r2l_loss tensor(3.9578, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_f) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "divine-joseph", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(9.0815, grad_fn=)\n", + "r2l_loss tensor(9.1060, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "hairy-logan", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "expanded-portal", + "metadata": {}, + "source": [ + "#### Random String" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "understanding-knock", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 123])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "raising-deviation", + "metadata": {}, + "outputs": [], + "source": [ + "# shuffle data\n", + "batch_f_input = batch_f['input_ids']\n", + "batch_shuffle_input = torch.cat(\n", + " (\n", + " batch_f_input[0][0:1],\n", + " torch.randint(\n", + " low=5, # 0-4 > ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']\n", + " high=29999,\n", + " size=(121,), # modified based on seqlen!\n", + " dtype=torch.long\n", + " ),\n", + " batch_f_input[0][-1:]\n", + " )\n", + ")\n", + "batch_shuffle_input = batch_shuffle_input.reshape(1,-1)\n", + "# labels\n", + "batch_shuffle_labels = torch.clone(batch_shuffle_input) \n", + "# batch_shuffle\n", + "batch_shuffle = batch_f.copy()\n", + "batch_shuffle['input_ids'] = batch_shuffle_input\n", + "batch_shuffle['labels'] = batch_shuffle_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "headed-aside", + "metadata": {}, + "outputs": [], + "source": [ + "batch_shuffle_rev = reverse_batch(batch_shuffle)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "distinct-steel", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "dirty-doctor", + "metadata": {}, + "source": [ + "##### Trial-1" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "virtual-church", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1##wichbattery
2deskNer
3Collectionought
4fisherMug
.........
118Mugfisher
119oughtCollection
120Nerdesk
121battery##wich
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 ##wich battery\n", + "2 desk Ner\n", + "3 Collection ought\n", + "4 fisher Mug\n", + ".. ... ...\n", + "118 Mug fisher\n", + "119 ought Collection\n", + "120 Ner desk\n", + "121 battery ##wich\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "artificial-salem", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(14.4735, grad_fn=)\n", + "r2l_loss tensor(13.7806, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "blond-pitch", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(14.0861, grad_fn=)\n", + "r2l_loss tensor(13.5071, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "clean-favorite", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "another-penetration", + "metadata": {}, + "source": [ + "##### Trial-2" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "brilliant-antarctica", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1##wichbattery
2deskNer
3Collectionought
4fisherMug
.........
118Mugfisher
119oughtCollection
120Nerdesk
121battery##wich
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 ##wich battery\n", + "2 desk Ner\n", + "3 Collection ought\n", + "4 fisher Mug\n", + ".. ... ...\n", + "118 Mug fisher\n", + "119 ought Collection\n", + "120 Ner desk\n", + "121 battery ##wich\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "animal-break", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(14.4735, grad_fn=)\n", + "r2l_loss tensor(13.7806, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "unlikely-table", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(14.0861, grad_fn=)\n", + "r2l_loss tensor(13.5071, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "boolean-supplier", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "mexican-canadian", + "metadata": {}, + "source": [ + "##### Trial-3" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "assured-bahrain", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1##wichbattery
2deskNer
3Collectionought
4fisherMug
.........
118Mugfisher
119oughtCollection
120Nerdesk
121battery##wich
122[SEP][SEP]
\n", + "

123 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 ##wich battery\n", + "2 desk Ner\n", + "3 Collection ought\n", + "4 fisher Mug\n", + ".. ... ...\n", + "118 Mug fisher\n", + "119 ought Collection\n", + "120 Ner desk\n", + "121 battery ##wich\n", + "122 [SEP] [SEP]\n", + "\n", + "[123 rows x 2 columns]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "promotional-panel", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(14.4735, grad_fn=)\n", + "r2l_loss tensor(13.7806, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "moving-karma", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(14.0861, grad_fn=)\n", + "r2l_loss tensor(13.5071, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "clean-gibson", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/elmo-sanity-check-v2.ipynb b/elmo-sanity-check-v2.ipynb new file mode 100644 index 0000000..9004f6f --- /dev/null +++ b/elmo-sanity-check-v2.ipynb @@ -0,0 +1,1104 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "working-fisher", + "metadata": {}, + "source": [ + "## Prepare Data + Model" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "quiet-macro", + "metadata": {}, + "outputs": [], + "source": [ + "# !cat examples/data/text_forward.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "resistant-denver", + "metadata": {}, + "outputs": [], + "source": [ + "# !ls -al ./outputs/en.1-percent.elmo-bert-causal-fixr2l" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "egyptian-association", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "robust-strategy", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n", + "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n", + "from transformers import BertConfig" + ] + }, + { + "cell_type": "markdown", + "id": "beneficial-franchise", + "metadata": {}, + "source": [ + "#### Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "supported-strategy", + "metadata": {}, + "outputs": [], + "source": [ + "model = ELMOBertLMHeadModel.from_pretrained(\n", + " \"./outputs/en.1-percent.elmo-bert-causal-fixr2l\"\n", + ") # use pre-trained model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "emotional-moisture", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model in eval mode for consistency\n" + ] + } + ], + "source": [ + 
"model.eval()\n", + "print(\"Model in eval mode for consistency\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "following-execution", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "sized-enemy", + "metadata": {}, + "source": [ + "#### Data" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "previous-crack", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-11-13 00:29:52.826 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n" + ] + } + ], + "source": [ + "%%capture\n", + "\n", + "from newlm.utils.file_util import read_from_yaml\n", + "config_file = read_from_yaml('examples/configs/run.1-percent-elmo-bert-causal.yaml')\n", + "\n", + "# lm builder (helper)\n", + "elmo_lm_builder = ELMOLMBuilder(\n", + " model_config = config_file['lm']['model']['config'],\n", + " tokenizer=\"./outputs/en.1-percent.elmo-bert-causal-fixr2l\", # use pre-trained tokenizer\n", + " model_type=\"bert-causal-elmo\",\n", + " max_len=128\n", + ")\n", + "\n", + "# dataset-forward\n", + "train_path = \"./examples/data/text_forward-small.txt\"\n", + "ds_f = elmo_lm_builder._get_dataset(train_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "mighty-major", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + } + ], + "source": [ + "# trainer (helper)\n", + "from transformers import TrainingArguments, Trainer\n", + "args = TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n", + "\n", + "# dataloader-forward\n", + "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator, train_dataset=ds_f,)\n", + "dl_f = trainer.get_train_dataloader() # Data Loader-forward" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "attempted-tomato", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 127])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f = next(iter(dl_f))\n", + "batch_f['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fallen-statement", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "paperback-luther", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model in eval mode for consistency\n" + ] + } + ], + "source": [ + "model.eval()\n", + "print(\"Model in eval mode for consistency\")" + ] + }, + { + "cell_type": "markdown", + "id": "fluid-product", + "metadata": {}, + "source": [ + "## Sanity Check" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "dimensional-sight", + "metadata": {}, + "outputs": [], + "source": [ + "# batch_f" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "environmental-experience", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def reverse_batch(batch_f):\n", + " # reverse input\n", + " batch_f_input = torch.clone(batch_f['input_ids'])\n", + " batch_f_rev_input = torch.cat(\n", + " (\n", + " batch_f_input[0][0:1],\n", + " torch.flip(batch_f_input[0][1:-1], [0]),\n", + " batch_f_input[0][-1:]\n", + " )\n", + " )\n", + " batch_f_rev_input = 
batch_f_rev_input.reshape(1,-1)\n", + "\n", + " # reverse labels\n", + " batch_f_rev_labels = torch.clone(batch_f_rev_input)\n", + " \n", + " # batch_rev\n", + " batch_rev = batch_f.copy()\n", + " batch_rev['input_ids'] = batch_f_rev_input\n", + " batch_rev['labels'] = batch_f_rev_labels\n", + " \n", + " return batch_rev" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "informed-termination", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "def pandas_check(batch_f, batch_rev):\n", + " tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n", + " tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n", + " return pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "complex-music", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "tight-bennett", + "metadata": {}, + "source": [ + "#### Normal vs Reverse" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "respective-beauty", + "metadata": {}, + "outputs": [], + "source": [ + "batch_rev = reverse_batch(batch_f)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "direct-warren", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1R.
2##øJews
3##damong
4##berglived
.........
122lived##berg
123among##d
124Jews##ø
125.R
126[SEP][SEP]
\n", + "

127 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 R .\n", + "2 ##ø Jews\n", + "3 ##d among\n", + "4 ##berg lived\n", + ".. ... ...\n", + "122 lived ##berg\n", + "123 among ##d\n", + "124 Jews ##ø\n", + "125 . R\n", + "126 [SEP] [SEP]\n", + "\n", + "[127 rows x 2 columns]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_f, batch_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "fourth-salmon", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 127]), torch.Size([1, 127]))" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f['input_ids'].shape, batch_rev['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "violent-shell", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(4.4801, grad_fn=)\n", + "r2l_loss tensor(4.3589, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_f) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "related-mambo", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(9.1742, grad_fn=)\n", + "r2l_loss tensor(9.0804, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "retired-bahamas", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "broad-agriculture", + "metadata": {}, + "source": [ + "#### Random String" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "remarkable-sailing", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 127])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "fifth-booth", + "metadata": {}, + "outputs": [], + "source": [ + "# shuffle data\n", + "batch_f_input = batch_f['input_ids']\n", + "batch_shuffle_input = torch.cat(\n", + " (\n", + " batch_f_input[0][0:1],\n", + " torch.randint(\n", + " low=5, # 0-4 > ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']\n", + " high=29999,\n", + " size=(125,), # modified based on seqlen!\n", + " dtype=torch.long\n", + " ),\n", + " batch_f_input[0][-1:]\n", + " )\n", + ")\n", + "batch_shuffle_input = batch_shuffle_input.reshape(1,-1)\n", + "# labels\n", + "batch_shuffle_labels = torch.clone(batch_shuffle_input) \n", + "# batch_shuffle\n", + "batch_shuffle = batch_f.copy()\n", + "batch_shuffle['input_ids'] = batch_shuffle_input\n", + "batch_shuffle['labels'] = batch_shuffle_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "white-juice", + "metadata": {}, + "outputs": [], + "source": [ + "batch_shuffle_rev = reverse_batch(batch_shuffle)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "addressed-collective", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "abroad-johnson", + "metadata": {}, + "source": [ + "##### Trial-1" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "inclusive-station", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1Lex##leg
2neithersquinting
3feverы
4LeedsRussell
.........
122RussellLeeds
123ыfever
124squintingneither
125##legLex
126[SEP][SEP]
\n", + "

127 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 Lex ##leg\n", + "2 neither squinting\n", + "3 fever ы\n", + "4 Leeds Russell\n", + ".. ... ...\n", + "122 Russell Leeds\n", + "123 ы fever\n", + "124 squinting neither\n", + "125 ##leg Lex\n", + "126 [SEP] [SEP]\n", + "\n", + "[127 rows x 2 columns]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "hungarian-soundtrack", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.5902, grad_fn=)\n", + "r2l_loss tensor(13.4444, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "employed-yukon", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.7210, grad_fn=)\n", + "r2l_loss tensor(13.3258, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "voluntary-necessity", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "judicial-republic", + "metadata": {}, + "source": [ + "##### Trial-2" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "reported-sweden", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1ByrneClean
2vanillaWear
3departedflickered
4##ische##ghter
.........
122##ghter##ische
123flickereddeparted
124Wearvanilla
125CleanByrne
126[SEP][SEP]
\n", + "

127 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 Byrne Clean\n", + "2 vanilla Wear\n", + "3 departed flickered\n", + "4 ##ische ##ghter\n", + ".. ... ...\n", + "122 ##ghter ##ische\n", + "123 flickered departed\n", + "124 Wear vanilla\n", + "125 Clean Byrne\n", + "126 [SEP] [SEP]\n", + "\n", + "[127 rows x 2 columns]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "equipped-ordinary", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.4740, grad_fn=)\n", + "r2l_loss tensor(13.1427, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "celtic-spider", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.3491, grad_fn=)\n", + "r2l_loss tensor(13.3746, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rough-kelly", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "invisible-white", + "metadata": {}, + "source": [ + "##### Trial-3" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "chubby-progressive", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1SeveranceHild
2endorsenor
3reviewed##frey
4AloneFred
.........
122FredAlone
123##freyreviewed
124norendorse
125HildSeverance
126[SEP][SEP]
\n", + "

127 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 Severance Hild\n", + "2 endorse nor\n", + "3 reviewed ##frey\n", + "4 Alone Fred\n", + ".. ... ...\n", + "122 Fred Alone\n", + "123 ##frey reviewed\n", + "124 nor endorse\n", + "125 Hild Severance\n", + "126 [SEP] [SEP]\n", + "\n", + "[127 rows x 2 columns]" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pandas_check(batch_shuffle, batch_shuffle_rev)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "partial-devil", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.9457, grad_fn=)\n", + "r2l_loss tensor(13.5547, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle) # forward" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "structured-farmer", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(13.5305, grad_fn=)\n", + "r2l_loss tensor(13.4759, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_shuffle_rev) # reverse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "funky-space", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/elmo-sanity-check.ipynb b/elmo-sanity-check.ipynb new file mode 100644 index 0000000..9dc106d --- /dev/null +++ b/elmo-sanity-check.ipynb @@ -0,0 +1,669 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "laden-tsunami", + "metadata": {}, + "source": [ + "## Prepare data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "lovely-flood", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"./examples/data/text_forward-small.txt\", \"r+\") as fr:\n", + " lines = fr.readlines()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "false-bernard", + "metadata": {}, + "outputs": [], + "source": [ + "lines = [line.strip() for line in lines]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "funded-listing", + "metadata": {}, + "outputs": [], + "source": [ + "lines = [line.split() for line in lines]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "medical-visiting", + "metadata": {}, + "outputs": [], + "source": [ + "lines = [line[::-1] for line in lines]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "eastern-publicity", + "metadata": {}, + "outputs": [], + "source": [ + "lines = [\" \".join(line) for line in lines]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "alien-nightmare", + "metadata": {}, + "outputs": [], + "source": [ + "lines = lines[::-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "acute-black", + "metadata": {}, + "outputs": [], + "source": [ + "lines = \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "modified-links", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"./examples/data/text_backward-small.txt\", \"w+\") as fw:\n", + " fw.write(lines)" + ] + }, + 
{ + "cell_type": "markdown", + "id": "absolute-world", + "metadata": {}, + "source": [ + "## Sanity Check" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "hollow-newsletter", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel\n", + "from newlm.lm.elmo.lm_builder import ELMOLMBuilder\n", + "from transformers import BertConfig" + ] + }, + { + "cell_type": "markdown", + "id": "chronic-crack", + "metadata": {}, + "source": [ + "Model from scratch" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "identified-ivory", + "metadata": {}, + "outputs": [], + "source": [ + "from newlm.utils.file_util import read_from_yaml\n", + "config_file = read_from_yaml('examples/configs/run.1-percent-bert-causal.yaml')\n", + "\n", + "elmo_lm_builder = ELMOLMBuilder(\n", + " model_config = config_file['lm']['model']['config'], # no pretrained model\n", + " tokenizer=\"./outputs/en.1-percent.elmo-bert-causal\", # use pre-trained tokenizer\n", + " model_type=\"bert-causal-elmo\",\n", + " max_len=128\n", + ")\n", + "\n", + "# model\n", + "config = BertConfig(**elmo_lm_builder.model_config)\n", + "model = ELMOBertLMHeadModel(config=config)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "boring-workplace", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model in eval mode for consistency\n" + ] + } + ], + "source": [ + "model.eval()\n", + "print(\"Model in eval mode for consistency\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "deluxe-expert", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-11-12 11:07:27.881 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n", + "2021-11-12 11:07:29.324 | INFO | newlm.lm.elmo.lm_builder:_get_dataset:142 - Constructing roBERTa style dataset\n" + ] + } + ], + "source": [ + "%%capture\n", + "\n", + "# dataset-forward\n", + "train_path = \"./examples/data/text_forward-small.txt\"\n", + "ds_f = elmo_lm_builder._get_dataset(train_path)\n", + "\n", + "# dataset-backward\n", + "train_path = \"./examples/data/text_backward-small.txt\"\n", + "ds_b = elmo_lm_builder._get_dataset(train_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "immune-quick", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n", + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + } + ], + "source": [ + "# trainer (for helper)\n", + "from transformers import TrainingArguments, Trainer\n", + "args = TrainingArguments(output_dir=\"tmpout\",**config_file['lm']['hf_trainer']['args'])\n", + "\n", + "trainer = Trainer(model=model, args=args, data_collator=elmo_lm_builder.data_collator,\n", + " train_dataset=ds_f,\n", + ")\n", + "dl_f = trainer.get_train_dataloader() # Data Loader-forward\n", + "\n", + "trainer = Trainer(model=model, args=args,data_collator=elmo_lm_builder.data_collator,\n", + " train_dataset=ds_b,\n", + ")\n", + "dl_b = trainer.get_train_dataloader() # Data Loader-backward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "enabling-business", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "preceding-evening", + "metadata": {}, + "source": [ + 
"### Compare text_forward and text_backward" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "promising-panama", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 127]), torch.Size([1, 127]))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_f = next(iter(dl_f))\n", + "batch_b = next(iter(dl_b))\n", + "\n", + "batch_f['input_ids'].shape, batch_b['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "italic-treaty", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardbackward
0[CLS][CLS]
1R.
2##øJews
3##damong
4##berglived
.........
122livedR
123among##ø
124Jews##d
125.##berg
126[SEP][SEP]
\n", + "

127 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward backward\n", + "0 [CLS] [CLS]\n", + "1 R .\n", + "2 ##ø Jews\n", + "3 ##d among\n", + "4 ##berg lived\n", + ".. ... ...\n", + "122 lived R\n", + "123 among ##ø\n", + "124 Jews ##d\n", + "125 . ##berg\n", + "126 [SEP] [SEP]\n", + "\n", + "[127 rows x 2 columns]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n", + "tokens_b = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_b['input_ids'][0])\n", + "\n", + "import pandas as pd\n", + "pd.DataFrame({\"forward\": tokens_f, \"backward\": tokens_b})" + ] + }, + { + "cell_type": "markdown", + "id": "attached-nursing", + "metadata": {}, + "source": [ + "Here we can see that the data is not completely flip when the tokenizer couldn't parse a single word into a single id" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "graphic-paintball", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(10.5425, grad_fn=)\n", + "r2l_loss tensor(10.4237, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_f)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "frequent-scheduling", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(10.5305, grad_fn=)\n", + "r2l_loss tensor(10.4198, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "amber-auckland", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "unique-jungle", + "metadata": {}, + "source": [ + "### From batch_forward compare Normal vs Rev" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "right-coalition", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "batch_f_input = torch.clone(batch_f['input_ids'])\n", + "batch_f_rev_input = torch.cat(\n", + " (\n", + " batch_f_input[0][0:1],\n", + " torch.flip(batch_f_input[0][1:-1], [0]),\n", + " batch_f_input[0][-1:]\n", + " )\n", + ")\n", + "batch_f_rev_input = batch_f_rev_input.reshape(1,-1)\n", + "\n", + "batch_f_labels = torch.clone(batch_f['labels'])\n", + "batch_f_rev_labels = torch.cat(\n", + " (\n", + " batch_f_labels[0][0:1],\n", + " torch.flip(batch_f_labels[0][1:-1], [0]),\n", + " batch_f_labels[0][-1:]\n", + " )\n", + ")\n", + "batch_f_rev_labels = batch_f_rev_labels.reshape(1,-1)\n", + "\n", + "batch_rev = batch_f.copy()\n", + "batch_rev['input_ids'] = batch_f_rev_input\n", + "batch_rev['labels'] = batch_f_rev_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "terminal-withdrawal", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forwardreverse
0[CLS][CLS]
1R.
2##øJews
3##damong
4##berglived
.........
122lived##berg
123among##d
124Jews##ø
125.R
126[SEP][SEP]
\n", + "

127 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " forward reverse\n", + "0 [CLS] [CLS]\n", + "1 R .\n", + "2 ##ø Jews\n", + "3 ##d among\n", + "4 ##berg lived\n", + ".. ... ...\n", + "122 lived ##berg\n", + "123 among ##d\n", + "124 Jews ##ø\n", + "125 . R\n", + "126 [SEP] [SEP]\n", + "\n", + "[127 rows x 2 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens_f = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_f['input_ids'][0])\n", + "tokens_f_rev = elmo_lm_builder.tokenizer.convert_ids_to_tokens(batch_rev['input_ids'][0])\n", + "\n", + "import pandas as pd\n", + "pd.DataFrame({\"forward\": tokens_f, \"reverse\": tokens_f_rev})" + ] + }, + { + "cell_type": "markdown", + "id": "continued-recommendation", + "metadata": {}, + "source": [ + "with the exception of [CLS] and [SEP], the data are completely flip" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "narrow-pipeline", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(10.5425, grad_fn=)\n", + "r2l_loss tensor(10.4237, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_f)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "desirable-peripheral", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "l2r_loss tensor(10.5323, grad_fn=)\n", + "r2l_loss tensor(10.4094, grad_fn=)\n" + ] + } + ], + "source": [ + "res = model(**batch_rev)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/newlm/lm/elmo/lm_builder.py b/newlm/lm/elmo/lm_builder.py index 590a05f..28bf661 100644 --- a/newlm/lm/elmo/lm_builder.py +++ b/newlm/lm/elmo/lm_builder.py @@ -135,7 +135,7 @@ def preprocess_function(examples): encoded_dataset = dataset.map(preprocess_function, batched=True) return encoded_dataset["train"] - def __get_dataset(self, train_path): + def _get_dataset(self, train_path): dataset = self.__get_dataset_via_ds(train_path)["input_ids"] print(len(dataset)) diff --git a/newlm/lm/elmo/modeling_elmo/elmo_head.py b/newlm/lm/elmo/modeling_elmo/elmo_head.py index 4683fc7..fc3e0cd 100644 --- a/newlm/lm/elmo/modeling_elmo/elmo_head.py +++ b/newlm/lm/elmo/modeling_elmo/elmo_head.py @@ -190,6 +190,8 @@ def forward( last_hidden_states, flip_labels, r2l=True ) + print("l2r_loss", l2r_loss) + print("r2l_loss", r2l_loss) total_loss = l2r_loss + r2l_loss if labels is not None else None return ElmoGPTCausalLMOutput(