From bf4482fac91a0286e88b9fcbfcb34cdbeb32e38c Mon Sep 17 00:00:00 2001 From: jon-tow Date: Thu, 8 Feb 2024 16:31:11 +0000 Subject: [PATCH] fix(docs): make docs consistent with latest CamelCase --- stablelm_repos/stable-code-3b | 1 + stablelm_repos/stablelm-2-1_6b | 1 + stablelm_repos/stablelm-2-zephyr-1_6b | 1 + stablelm_repos/stablelm-3b-4e1t | 1 + stablelm_repos/stablelm-zephyr-3b | 1 + stablelm_repos/todo | 1 + test/stablelm-2-1_6b | 1 + test/stablelm-2-1_6b-dev | 1 + test/stablelm-3b-4e1t | 1 + test/stablelm-3b-4e1t-dev | 1 + test/test copy.ipynb | 752 +++++++++++++++ test/todo | 90 ++ todo | 1258 +++++++++++++++++++++++++ 13 files changed, 2110 insertions(+) create mode 160000 stablelm_repos/stable-code-3b create mode 160000 stablelm_repos/stablelm-2-1_6b create mode 160000 stablelm_repos/stablelm-2-zephyr-1_6b create mode 160000 stablelm_repos/stablelm-3b-4e1t create mode 160000 stablelm_repos/stablelm-zephyr-3b create mode 100644 stablelm_repos/todo create mode 160000 test/stablelm-2-1_6b create mode 160000 test/stablelm-2-1_6b-dev create mode 160000 test/stablelm-3b-4e1t create mode 160000 test/stablelm-3b-4e1t-dev create mode 100644 test/test copy.ipynb create mode 100644 test/todo create mode 100644 todo diff --git a/stablelm_repos/stable-code-3b b/stablelm_repos/stable-code-3b new file mode 160000 index 00000000000000..7c4583da0a36c5 --- /dev/null +++ b/stablelm_repos/stable-code-3b @@ -0,0 +1 @@ +Subproject commit 7c4583da0a36c56c7bf040864b737a70d45d1f4b diff --git a/stablelm_repos/stablelm-2-1_6b b/stablelm_repos/stablelm-2-1_6b new file mode 160000 index 00000000000000..efd6eabee32eab --- /dev/null +++ b/stablelm_repos/stablelm-2-1_6b @@ -0,0 +1 @@ +Subproject commit efd6eabee32eabd8675eae49359287ad061fe197 diff --git a/stablelm_repos/stablelm-2-zephyr-1_6b b/stablelm_repos/stablelm-2-zephyr-1_6b new file mode 160000 index 00000000000000..28b4cfb0351646 --- /dev/null +++ b/stablelm_repos/stablelm-2-zephyr-1_6b @@ -0,0 +1 @@ +Subproject commit 28b4cfb03516461c44b131d362bd378d74a0522b diff --git a/stablelm_repos/stablelm-3b-4e1t b/stablelm_repos/stablelm-3b-4e1t new file mode 160000 index 00000000000000..ef92c21fd57a2c --- /dev/null +++ b/stablelm_repos/stablelm-3b-4e1t @@ -0,0 +1 @@ +Subproject commit ef92c21fd57a2c5e79f220acd1bcf44a7bea5f54 diff --git a/stablelm_repos/stablelm-zephyr-3b b/stablelm_repos/stablelm-zephyr-3b new file mode 160000 index 00000000000000..e20b0d644459c8 --- /dev/null +++ b/stablelm_repos/stablelm-zephyr-3b @@ -0,0 +1 @@ +Subproject commit e20b0d644459c889294d51be537c7d15971f9fb1 diff --git a/stablelm_repos/todo b/stablelm_repos/todo new file mode 100644 index 00000000000000..50c9330967d1a7 --- /dev/null +++ b/stablelm_repos/todo @@ -0,0 +1 @@ +git commit -m 'merge: upload `transformers` implementation \ No newline at end of file diff --git a/test/stablelm-2-1_6b b/test/stablelm-2-1_6b new file mode 160000 index 00000000000000..1909ae19b3adb0 --- /dev/null +++ b/test/stablelm-2-1_6b @@ -0,0 +1 @@ +Subproject commit 1909ae19b3adb0a79d1ed313154cb67afe35d587 diff --git a/test/stablelm-2-1_6b-dev b/test/stablelm-2-1_6b-dev new file mode 160000 index 00000000000000..2d597a8383d251 --- /dev/null +++ b/test/stablelm-2-1_6b-dev @@ -0,0 +1 @@ +Subproject commit 2d597a8383d251f064f9052b39c0c6e69cc13f3e diff --git a/test/stablelm-3b-4e1t b/test/stablelm-3b-4e1t new file mode 160000 index 00000000000000..e3be657f4d1b78 --- /dev/null +++ b/test/stablelm-3b-4e1t @@ -0,0 +1 @@ +Subproject commit e3be657f4d1b78eb7520637ba922448a1ee456bd diff --git 
a/test/stablelm-3b-4e1t-dev b/test/stablelm-3b-4e1t-dev new file mode 160000 index 00000000000000..7991e5d17df5d5 --- /dev/null +++ b/test/stablelm-3b-4e1t-dev @@ -0,0 +1 @@ +Subproject commit 7991e5d17df5d5e547363b75ff956c0776540231 diff --git a/test/test copy.ipynb b/test/test copy.ipynb new file mode 100644 index 00000000000000..5958f2febe8837 --- /dev/null +++ b/test/test copy.ipynb @@ -0,0 +1,752 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fsx/home-guac/stable-lm/transformers/.env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/fsx/home-guac/stable-lm/transformers/src/transformers/utils/hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import torch\n", + "from transformers import (\n", + " StableLmForCausalLM,\n", + " StableLmForSequenceClassification,\n", + " StableLmModel,\n", + " StableLmConfig,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "StableLmConfig {\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 0,\n", + " \"eos_token_id\": 0,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_dropout\": 0.0,\n", + " \"hidden_size\": 2560,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 6912,\n", + " \"layer_norm_eps\": 1e-05,\n", + " \"max_position_embeddings\": 4096,\n", + " \"model_type\": \"stablelm\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 32,\n", + " \"num_key_value_heads\": 32,\n", + " \"partial_rotary_factor\": 0.25,\n", + " \"rope_scaling\": null,\n", + " \"rope_theta\": 10000,\n", + " \"tie_word_embeddings\": false,\n", + " \"transformers_version\": \"4.38.0.dev0\",\n", + " \"use_cache\": true,\n", + " \"use_qkv_bias\": true,\n", + " \"vocab_size\": 50304\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "c = StableLmConfig()\n", + "print(c)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# model = StableLmForCausalLM.from_pretrained(\"jon-tow/stablelm-3b-4e1t-dev\", attn_implementation=\"flash_attention_2\")\n", + "model = StableLmForCausalLM.from_pretrained(\"jon-tow/stablelm-3b-4e1t-dev\", device_map=\"auto\", torch_dtype=torch.bfloat16,\n", + " attn_implementation=\"flash_attention_2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "path = \"jon-tow/stablelm-3b-4e1t-dev\"\n", + "# model = StableLmForCausalLM.from_pretrained(\n", + "# path,a\n", + "# torch_dtype=torch.bfloat16,\n", + "# low_cpu_mem_usage=True,\n", + "# device_map={\"\": 0},\n", + "# )\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " path,\n", + " device_map=\"auto\",\n", + " torch_dtype=\"auto\",\n", + " attn_implementation=\"flash_attention_2\",\n", + ")" + ] + }, + { + "cell_type": "code", + 
"execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31mERROR: You must give at least one requirement to install (see \"pip help install\")\u001b[0m\u001b[31m\n", + "\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install -i https://pypi.org/simple bitsandbytes" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['My favorite food has always been iced tea. I love the way it tastes, the way it smells, and the way it makes',\n", + " 'Hello.\\nI am a newbie to the forum and I am looking for some advice.\\nI']" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_native" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['My favorite food has always been iced tea. I love the way it tastes, the way it smells, and the way it makes',\n", + " 'Hello.\\nI am a newbie to the forum and I am looking for some advice.\\nI']" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_fa_2" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "StableLMConfig {\n", + " \"_name_or_path\": \"jon-tow/stablelm-3b-4e1t-dev\",\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 0,\n", + " \"eos_token_id\": 0,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_dropout\": 0.0,\n", + " \"hidden_size\": 2560,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 6912,\n", + " \"layer_norm_eps\": 1e-05,\n", + " \"max_position_embeddings\": 4096,\n", + " \"model_type\": \"stablelm\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 32,\n", + " \"num_key_value_heads\": 32,\n", + " \"partial_rotary_factor\": 0.25,\n", + " \"rope_scaling\": null,\n", + " \"rope_theta\": 10000,\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.38.0.dev0\",\n", + " \"use_cache\": true,\n", + " \"use_qkv_bias\": false,\n", + " \"vocab_size\": 50304\n", + "}" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fsx/home-guac/stable-lm/transformers/.env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/fsx/home-guac/stable-lm/transformers/src/transformers/utils/hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. 
Use `HF_HOME` instead.\n", + " warnings.warn(\n", + "A new version of the following files was downloaded from https://huggingface.co/stabilityai/stablelm-2-1_6b:\n", + "- tokenization_arcade100k.py\n", + ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", + "A new version of the following files was downloaded from https://huggingface.co/stabilityai/stablelm-2-1_6b:\n", + "- configuration_stablelm_epoch.py\n", + ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", + "A new version of the following files was downloaded from https://huggingface.co/stabilityai/stablelm-2-1_6b:\n", + "- modeling_stablelm_epoch.py\n", + ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", + "Setting `pad_token_id` to `eos_token_id`:100257 for open-end generation.\n", + "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", + "Setting `pad_token_id` to `eos_token_id`:100257 for open-end generation.\n", + "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" + ] + } + ], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"stabilityai/stablelm-2-1_6b\", trust_remote_code=True)\n", + "model_native = AutoModelForCausalLM.from_pretrained(\n", + " \"stabilityai/stablelm-2-1_6b\",\n", + " trust_remote_code=True,\n", + " torch_dtype=\"auto\",\n", + ")\n", + "\n", + "model_transformers = AutoModelForCausalLM.from_pretrained(\n", + " \"jon-tow/stablelm-2-1_6b-dev\",\n", + " torch_dtype=\"auto\",\n", + ")\n", + "\n", + "texts = [\"My favorite food has always been \", \"Hello\"]\n", + "\n", + "tokenizer.padding_side = \"right\"\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "inputs = tokenizer(texts, return_tensors=\"pt\", padding=True).to(model_native.device)\n", + "\n", + "output_native = model_native.generate(**inputs, max_new_tokens=20, do_sample=False)\n", + "output_native = tokenizer.batch_decode(output_native, skip_special_tokens=True)\n", + "\n", + "\n", + "output_fa_2 = model_transformers.generate(**inputs, max_new_tokens=20, do_sample=False)\n", + "output_fa_2 = tokenizer.batch_decode(output_fa_2, skip_special_tokens=True)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['My favorite food has always been 100% fruit. I love fruit. I love fruit in all its forms. I love',\n", + " 'Hello\\nI have a problem with my 2006 320d. I have']" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_native" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['My favorite food has always been 100% fruit. I love fruit. I love fruit in all its forms. I love',\n", + " 'Hello\\nI have a problem with my 2006 320d. 
I have']" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_fa_2" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:100257 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input ids: tensor([[ 5159, 7075, 3691, 706, 2744, 1027, 23317, 11, 719,\n", + " 31445],\n", + " [100257, 100257, 100257, 100257, 100257, 100257, 100257, 100257, 100257,\n", + " 9906]], device='cuda:0')\n", + "Hello, I am a 20 year old male and I have been having a problem with my penis\n" + ] + } + ], + "source": [ + "tokenizer.padding_side = \"left\"\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "inputs = tokenizer([\n", + " \"My favorite food has always been pizza, but lately\",\n", + " \"Hello\"\n", + " ], return_tensors=\"pt\", padding=True).to(model.device)\n", + "print(f\"Input ids: {inputs.input_ids}\")\n", + "\n", + "tokens = model_native.generate(\n", + " **inputs,\n", + " max_new_tokens=20,\n", + " do_sample=False,\n", + " use_cache=True\n", + ")\n", + "print(tokenizer.decode(tokens[1], skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:100257 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input ids: tensor([[ 5159, 7075, 3691, 706, 2744, 1027, 23317, 11, 719,\n", + " 31445],\n", + " [100257, 100257, 100257, 100257, 100257, 100257, 100257, 100257, 100257,\n", + " 9906]], device='cuda:0')\n", + "Hello, I am a 20 year old male and I have been having a problem with my penis\n" + ] + } + ], + "source": [ + "tokenizer.padding_side = \"left\"\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "inputs = tokenizer([\n", + " \"My favorite food has always been pizza, but lately\",\n", + " \"Hello\"\n", + " ], return_tensors=\"pt\", padding=True).to(model.device)\n", + "print(f\"Input ids: {inputs.input_ids}\")\n", + "\n", + "tokens = model_transformers.generate(\n", + " **inputs,\n", + " max_new_tokens=20,\n", + " do_sample=False,\n", + " use_cache=True\n", + ")\n", + "print(tokenizer.decode(tokens[1], skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[autoreload of transformers.models.auto failed: Traceback (most recent call last):\n", + " File \"/fsx/home-guac/stable-lm/transformers/.env/lib/python3.10/site-packages/IPython/extensions/autoreload.py\", line 276, in check\n", + " superreload(m, reload, self.old_objects)\n", + " File \"/fsx/home-guac/stable-lm/transformers/.env/lib/python3.10/site-packages/IPython/extensions/autoreload.py\", line 475, in superreload\n", + " module = reload(module)\n", + " File \"/usr/lib/python3.10/importlib/__init__.py\", line 139, in reload\n", + " name = module.__spec__.name\n", + " File \"/fsx/home-guac/stable-lm/transformers/src/transformers/utils/import_utils.py\", line 1349, in __getattr__\n", + " if name in self._objects:\n", + " File 
\"/fsx/home-guac/stable-lm/transformers/src/transformers/utils/import_utils.py\", line 1349, in __getattr__\n", + " if name in self._objects:\n", + " File \"/fsx/home-guac/stable-lm/transformers/src/transformers/utils/import_utils.py\", line 1349, in __getattr__\n", + " if name in self._objects:\n", + " [Previous line repeated 4 more times]\n", + "RecursionError: maximum recursion depth exceeded while calling a Python object\n", + "]\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "data": { + "text/plain": [ + "StableLMForCausalLM(\n", + " (model): StableLMModel(\n", + " (embed_tokens): Embedding(50304, 2560)\n", + " (layers): ModuleList(\n", + " (0-31): 32 x StableLMDecoderLayer(\n", + " (self_attn): StableLMAttention(\n", + " (q_proj): Linear(in_features=2560, out_features=2560, bias=False)\n", + " (k_proj): Linear(in_features=2560, out_features=2560, bias=False)\n", + " (v_proj): Linear(in_features=2560, out_features=2560, bias=False)\n", + " (o_proj): Linear(in_features=2560, out_features=2560, bias=False)\n", + " (rotary_emb): StableLMRotaryEmbedding()\n", + " )\n", + " (mlp): StableLMMLP(\n", + " (gate_proj): Linear(in_features=2560, out_features=6912, bias=False)\n", + " (up_proj): Linear(in_features=2560, out_features=6912, bias=False)\n", + " (down_proj): Linear(in_features=6912, out_features=2560, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n", + " (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (lm_head): Linear(in_features=2560, out_features=50304, bias=False)\n", + ")" + ] + }, + "execution_count": 105, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "# t = \"\"\n", + "# p = \"jon-tow/stablelm-2-1_6b-dev\"\n", + "p = \"jon-tow/stablelm-3b-4e1t-dev\"\n", + "device = \"cuda\"\n", + "tokenizer = AutoTokenizer.from_pretrained(p)\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " p,\n", + " torch_dtype=\"auto\",\n", + ")\n", + "model.to(device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "A new version of the following files was downloaded from https://huggingface.co/stabilityai/stablelm-3b-4e1t:\n", + "- configuration_stablelm_epoch.py\n", + ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n" + ] + } + ], + "source": [ + "model = AutoModelForCausalLM.from_pretrained(\"stabilityai/stablelm-3b-4e1t\", torch_dtype=torch.bfloat16, attn_implementation=\"flash_at0tention_2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" + ] + }, + { + "data": { + "text/plain": [ + "[\"The weather is always wonderful in the San Juan Islands, and whether you're a vacationer or a resident, here are some ideas for fun! 
If you have\"]" + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "device = \"cuda\"\n", + "model.to(device)\n", + "prompt = \"The weather is always wonderful in\"\n", + "model_inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", + "# model_inputs = tokenizer(\"The weather is always wonderful in\", return_tensors=\"pt\").to(model.device)\n", + "generated_ids = model.generate(**model_inputs, max_length=32, do_sample=True)\n", + "responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n", + "responses " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "responses" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['The weather is always wonderful in Santa Barbara and, for visitors hoping to make the move to our beautiful seaside city, this town offers plenty of great places to']\n" + ] + } + ], + "source": [ + "model_inputs = tokenizer(\"The weather is always wonderful in\", return_tensors=\"pt\").to(model.device)\n", + "generated_ids = model.generate(**model_inputs, max_length=32, do_sample=True)\n", + "print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean on dim=-1: tensor([[2.7146, 2.4245, 1.5616, 1.4424, 2.6790]], grad_fn=)\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([ 7.1030, -1.4195, 9.9206, 7.7008, 4.9891, 4.2169, 5.5426, 3.7878,\n", + " 6.7593, 5.7360, 8.4691, 5.5448, 5.0544, 10.4129, 8.5573, 13.0405,\n", + " 7.3265, 3.5868, 6.1106, 5.9406, 5.6376, 5.7490, 5.4850, 4.8124,\n", + " 5.1991, 4.6419, 4.5719, 9.9588, 6.7222, 4.5070],\n", + " grad_fn=)" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_device = \"cuda:0\"\n", + "\n", + "input_ids = {\n", + " \"input_ids\": torch.tensor([[ 510, 8588, 310, 1900, 9386]], dtype=torch.long)\n", + "}\n", + "model.eval()\n", + "output = model(**input_ids).logits\n", + "print(f\"Mean on dim=-1: {output.mean(dim=-1)}\")\n", + "output[0, 0, 0:30]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input ids: tensor([[ 3220, 7583, 2739, 556, 1900, 644, 22534, 13, 533, 21403]])\n" + ] + }, + { + "ename": "TypeError", + "evalue": "The current model class (StableLMForCausalLM) is not compatible with `.generate()`, as it doesn't have a language model head. 
Please use one of the following classes instead: {'StableLMForCausalLM'}", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[23], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m inputs \u001b[38;5;241m=\u001b[39m tokenizer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMy favorite food has always been pizza, but lately\u001b[39m\u001b[38;5;124m\"\u001b[39m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mto(model\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput ids: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minputs\u001b[38;5;241m.\u001b[39minput_ids\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m20\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(tokenizer\u001b[38;5;241m.\u001b[39mdecode(tokens[\u001b[38;5;241m0\u001b[39m], skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m))\n", + "File \u001b[0;32m~/stable-lm/transformers/.env/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/stable-lm/transformers/src/transformers/generation/utils.py:1279\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1276\u001b[0m synced_gpus \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 1278\u001b[0m \u001b[38;5;66;03m# 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call\u001b[39;00m\n\u001b[0;32m-> 1279\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_model_class\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1281\u001b[0m \u001b[38;5;66;03m# priority: `generation_config` argument > `model.generation_config` (the default generation config)\u001b[39;00m\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m generation_config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;66;03m# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \u001b[38;5;66;03m# three conditions must be met\u001b[39;00m\n\u001b[1;32m 1285\u001b[0m \u001b[38;5;66;03m# 1) the generation config must have been created from the model config (`_from_model_config` field);\u001b[39;00m\n\u001b[1;32m 1286\u001b[0m \u001b[38;5;66;03m# 2) the generation config must have seen no modification since its creation (the hash is the same);\u001b[39;00m\n\u001b[1;32m 1287\u001b[0m \u001b[38;5;66;03m# 3) the user must have set generation parameters in the model config.\u001b[39;00m\n", + "File \u001b[0;32m~/stable-lm/transformers/src/transformers/generation/utils.py:1065\u001b[0m, in \u001b[0;36mGenerationMixin._validate_model_class\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1063\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m generate_compatible_classes:\n\u001b[1;32m 1064\u001b[0m exception_message \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Please use one of the following classes instead: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgenerate_compatible_classes\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m-> 1065\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(exception_message)\n", + "\u001b[0;31mTypeError\u001b[0m: The current model class (StableLMForCausalLM) is not compatible with `.generate()`, as it doesn't have a language model head. 
Please use one of the following classes instead: {'StableLMForCausalLM'}" + ] + } + ], + "source": [ + "inputs = tokenizer(\"My favorite food has always been pizza, but lately \", return_tensors=\"pt\").to(model.device)\n", + "print(f\"Input ids: {inputs.input_ids}\")\n", + "tokens = model.generate(\n", + " **inputs,\n", + " max_new_tokens=20,\n", + " temperature=0.0,\n", + ")\n", + "print(tokenizer.decode(tokens[0], skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "import transformers\n", + "p = \"/fsx/home-guac/stable-lm/transformers/test/stablelm-3b-4e1t\"\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(p, trust_remote_code=True)\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(p, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Could not locate the configuration_stablelm_epoch.py inside /fsx/home-guac/stable-lm/transformers/test/stablelm-3b-4e1t.\n" + ] + }, + { + "ename": "OSError", + "evalue": "/fsx/home-guac/stable-lm/transformers/test/stablelm-3b-4e1t does not appear to have a file named configuration_stablelm_epoch.py. Checkout 'https://huggingface.co//fsx/home-guac/stable-lm/transformers/test/stablelm-3b-4e1t/None' for available files.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mtransformers\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mAutoModelForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mauto\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/stable-lm/transformers/src/transformers/models/auto/auto_factory.py:527\u001b[0m, in \u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantization_config\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 525\u001b[0m _ \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantization_config\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 527\u001b[0m config, kwargs \u001b[38;5;241m=\u001b[39m 
\u001b[43mAutoConfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 528\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 529\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 530\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 531\u001b[0m \u001b[43m \u001b[49m\u001b[43mcode_revision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcode_revision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 532\u001b[0m \u001b[43m \u001b[49m\u001b[43m_commit_hash\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_hash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 533\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 534\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 535\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 537\u001b[0m \u001b[38;5;66;03m# if torch_dtype=auto was passed here, ensure to pass it on\u001b[39;00m\n\u001b[1;32m 538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs_orig\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch_dtype\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[0;32m~/stable-lm/transformers/src/transformers/models/auto/configuration_auto.py:1115\u001b[0m, in \u001b[0;36mAutoConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 1113\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_remote_code \u001b[38;5;129;01mand\u001b[39;00m trust_remote_code:\n\u001b[1;32m 1114\u001b[0m class_ref \u001b[38;5;241m=\u001b[39m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto_map\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAutoConfig\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m-> 1115\u001b[0m config_class \u001b[38;5;241m=\u001b[39m \u001b[43mget_class_from_dynamic_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1116\u001b[0m \u001b[43m \u001b[49m\u001b[43mclass_ref\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcode_revision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcode_revision\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 1117\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1118\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(pretrained_model_name_or_path):\n\u001b[1;32m 1119\u001b[0m config_class\u001b[38;5;241m.\u001b[39mregister_for_auto_class()\n", + "File \u001b[0;32m~/stable-lm/transformers/src/transformers/dynamic_module_utils.py:488\u001b[0m, in \u001b[0;36mget_class_from_dynamic_module\u001b[0;34m(class_reference, pretrained_model_name_or_path, cache_dir, 
force_download, resume_download, proxies, token, revision, local_files_only, repo_type, code_revision, **kwargs)\u001b[0m\n\u001b[1;32m 486\u001b[0m code_revision \u001b[38;5;241m=\u001b[39m revision\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# And lastly we get the class inside our newly created module\u001b[39;00m\n\u001b[0;32m--> 488\u001b[0m final_module \u001b[38;5;241m=\u001b[39m \u001b[43mget_cached_module_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodule_file\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.py\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 491\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 492\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 495\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 496\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcode_revision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 497\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 498\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 499\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 500\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m get_class_in_module(class_name, final_module\u001b[38;5;241m.\u001b[39mreplace(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.py\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", + "File \u001b[0;32m~/stable-lm/transformers/src/transformers/dynamic_module_utils.py:294\u001b[0m, in \u001b[0;36mget_cached_module_file\u001b[0;34m(pretrained_model_name_or_path, module_file, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, repo_type, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 291\u001b[0m new_files \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 292\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 293\u001b[0m \u001b[38;5;66;03m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 294\u001b[0m resolved_module_file \u001b[38;5;241m=\u001b[39m \u001b[43mcached_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 295\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 296\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodule_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 297\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 298\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 299\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 300\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 301\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 302\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 303\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 304\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 305\u001b[0m \u001b[43m \u001b[49m\u001b[43m_commit_hash\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_commit_hash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 306\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 307\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_local \u001b[38;5;129;01mand\u001b[39;00m cached_module \u001b[38;5;241m!=\u001b[39m resolved_module_file:\n\u001b[1;32m 308\u001b[0m new_files\u001b[38;5;241m.\u001b[39mappend(module_file)\n", + "File \u001b[0;32m~/stable-lm/transformers/src/transformers/utils/hub.py:369\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 367\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misfile(resolved_file):\n\u001b[1;32m 368\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _raise_exceptions_for_missing_entries:\n\u001b[0;32m--> 369\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 370\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath_or_repo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m does not appear to have a file named \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfull_filename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. 
Checkout \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 371\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttps://huggingface.co/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath_or_repo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrevision\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m for available files.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 372\u001b[0m )\n\u001b[1;32m 373\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 374\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[0;31mOSError\u001b[0m: /fsx/home-guac/stable-lm/transformers/test/stablelm-3b-4e1t does not appear to have a file named configuration_stablelm_epoch.py. Checkout 'https://huggingface.co//fsx/home-guac/stable-lm/transformers/test/stablelm-3b-4e1t/None' for available files." + ] + } + ], + "source": [ + "model = transformers.AutoModelForCausalLM.from_pretrained(p, trust_remote_code=True, device_map = \"auto\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/todo b/test/todo new file mode 100644 index 00000000000000..2c82dfd516f026 --- /dev/null +++ b/test/todo @@ -0,0 +1,90 @@ +# What does this PR do? + +This PR adds modeling support for [`StableLM-3B-4E1T`](https://huggingface.co/stabilityai/stablelm-3b-4e1t)based models. + + + + + + +## Before submitting +- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). +- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), + Pull Request section? +- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link + to it if that's the case. +- [x] Did you make sure to update the documentation with your changes? Here are the + [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and + [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). +- [x] Did you write any new necessary tests? + +## Who can review? + +Anyone in the community is free to review the PR once the tests have passed. Feel free to tag +members/contributors who may be interested in your PR. + +cc @ArthurZucker + + + + +## Notes + +**TODO**: The current online implementation uses an early naming scheme for the [`model_type`](https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/e3be657f4d1b78eb7520637ba922448a1ee456bd/config.json#L16) +```json + "model_type": "stablelm_epoch", +``` +I've temporarily created a development model repository https://huggingface.co/jon-tow/stablelm-3b-4e1t-dev for unit testing [config archive mapping](https://github.com/jon-tow/transformers/blob/caf38d1659333ae826c0f64d2e3bec45e008081b/src/transformers/models/stablelm/configuration_stablelm.py#L24) which need to be updated before merging. +Is there a better way to handle this? 
I've noticed a similar [issue](https://github.com/huggingface/transformers/pull/26170#discussion_r1368903469) in the `Phi` model PR. + diff --git a/todo b/todo new file mode 100644 index 00000000000000..08394a8c1a7059 --- /dev/null +++ b/todo @@ -0,0 +1,1258 @@ +Hi, Arthur! I've pushed a commit to re-add the model based on `huggingface-cli add-new-model-like persimmon`. +These are the core diffs: +* Modifies `PersimmonAttention` to use unpacked q, k, v projections, replacing the fused `query_key_value` with separate `q_proj`, `k_proj`, `v_proj` (a small illustrative sketch of this weight split follows the license header below). +* Renames `PersimmonAttention.dense` to `o_proj` to avoid breaking compatibility with current checkpoints. +* Removes the `qk_norm` attributes. Our soon-to-be-released LM will use a different form of `QK-Normalization`, and keeping these attributes may confuse users. +* Adds `FlashAttention2` support as featured in the custom modeling implementation of the base model [`here`](https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/e3be657f4d1b78eb7520637ba922448a1ee456bd/modeling_stablelm_epoch.py#L300). +* Adds MQA/GQA support through `num_key_value_heads` as we internally experiment with GQA. If you'd like me to remove it, I'll be glad to do so. +Thank you for your time. + + + + +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
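# --------------------------------------------------------------------------
# Illustrative sketch (editorial aside, not part of the modeling file that
# follows): the first bullet in the PR comment above replaces Persimmon's
# fused `query_key_value` projection with separate `q_proj`/`k_proj`/`v_proj`.
# The helper below shows one way such a fused weight *could* be split,
# assuming a plain [q; k; v] concatenation along the output dimension.
# Real fused checkpoints (e.g. Persimmon/GPT-NeoX style) may interleave the
# heads, so any actual conversion must match the checkpoint's packing; the
# function name and signature here are hypothetical.
# --------------------------------------------------------------------------
import torch


def split_packed_qkv(
    packed_weight: torch.Tensor,  # shape: (q_size + 2 * kv_size, hidden_size)
    hidden_size: int,
    num_heads: int,
    num_key_value_heads: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV weight into q/k/v weights for separate nn.Linear layers."""
    head_dim = hidden_size // num_heads
    q_size = num_heads * head_dim             # rows belonging to q_proj
    kv_size = num_key_value_heads * head_dim  # rows belonging to k_proj / v_proj each
    q_w = packed_weight[:q_size, :]
    k_w = packed_weight[q_size : q_size + kv_size, :]
    v_w = packed_weight[q_size + kv_size :, :]
    return q_w, k_w, v_w


# Example (MHA case, so q/k/v each keep `hidden_size` output rows):
# packed = torch.randn(3 * 2560, 2560)
# q_w, k_w, v_w = split_packed_qkv(packed, hidden_size=2560, num_heads=32, num_key_value_heads=32)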
+""" PyTorch StableLM model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings +from .configuration_stablelm import StableLmConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "StableLmConfig" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm +class StableLmRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm +class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding): + """StableLmRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm +class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding): + """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->StableLm +class StableLmMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class StableLmAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + self.attention_dropout = nn.Dropout(config.attention_dropout) + self._init_rope() + + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention with Persimmon->StableLm + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = StableLmRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = StableLmLinearScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = StableLmDynamicNTKScalingRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_emb.dim], + query_states[..., self.rotary_emb.dim :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_emb.dim], + key_states[..., self.rotary_emb.dim :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) + attn_weights = self.attention_dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class StableLmFlashAttention2(StableLmAttention): + """ + StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # StableLmFlashAttention2 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_emb.dim], + query_states[..., self.rotary_emb.dim :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_emb.dim], + key_states[..., self.rotary_emb.dim :], + ) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
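+ # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, seq_len, num_heads, head_dim], the layout the flash-attn kernels expect.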
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout.p if self.training else 0.0 + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token, + first unpads the input, then computes the attention scores and pads the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
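+ # With a top-left aligned mask (older flash_attn releases), a causal kernel call is only correct when q_len == kv_len, so fall back to non-causal attention for single-token decoding queries, where attending to every cached position is the intended behavior anyway.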
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
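+ # (i.e. padding tokens sit at the start of each row, as expected for batched decoder-only generation).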
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": StableLmAttention, + "flash_attention_2": StableLmFlashAttention2, +} + + +# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonDecoderLayer with Persimmon->StableLm +class StableLmDecoderLayer(nn.Module): + def __init__(self, config: StableLmConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.mlp = StableLmMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +STABLELM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`StableLmConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare StableLM Model outputting raw hidden-states without any specific head on top.", + STABLELM_START_DOCSTRING, +) +class StableLmPreTrainedModel(PreTrainedModel): + config_class = StableLmConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["StableLmDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +STABLELM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
+ + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare StableLM Model outputting raw hidden-states without any specific head on top.", + STABLELM_START_DOCSTRING, +) +# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonModel with PERSIMMON->STABLELM,Persimmon->StableLm +class StableLmModel(StableLmPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`StableLmDecoderLayer`] + + Args: + config: StableLmConfig + """ + + def __init__(self, config: StableLmConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonModel.final_layernorm with final_layernorm->norm + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm,persimmon->stablelm +class StableLmForCausalLM(StableLmPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = StableLmModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) +
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, StableLmForCausalLM + + >>> model = StableLmForCausalLM.from_pretrained("stabilityai/stablelm-3b-4e1t") + >>> tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t") + + >>> prompt = "The weather is always wonderful in" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'The weather is always wonderful in the San Juan Islands, and whether you're a vacationer or a resident, here are some ideas for fun!' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = 
past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The StableLM transformer with a sequence classification head on top (linear layer). + + [`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + STABLELM_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->STABLELM,Llama->StableLm +class StableLmForSequenceClassification(StableLmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = StableLmModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 
"single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +from transformers import AutoModelForCausalLM, AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") +model = AutoModelForCausalLM.from_pretrained( + "stabilityai/stablelm-3b-4e1t", + torch_dtype="auto", + attn_implementation="flash_attention_2", +) +model.cuda() +inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device) +tokens = model.generate( + **inputs, + max_new_tokens=64, + temperature=0.75, + top_p=0.95, + do_sample=True, +) +print(tokenizer.decode(tokens[0], skip_special_tokens=True)) \ No newline at end of file