From cefa4cbf3f028522cac5b3aa290e1f56325df86c Mon Sep 17 00:00:00 2001 From: zgy Date: Tue, 24 Jan 2023 22:03:32 +0800 Subject: [PATCH] FIX #62; ADD GPT-J example #63 --- bminf/scheduler/__init__.py | 52 ++++++-- example/huggingface/gpt-j.ipynb | 211 ++++++++++++++++++++++++++++++++ requirements.txt | 3 +- 3 files changed, 257 insertions(+), 9 deletions(-) create mode 100644 example/huggingface/gpt-j.ipynb diff --git a/bminf/scheduler/__init__.py b/bminf/scheduler/__init__.py index 65a9b4a..8d638b9 100644 --- a/bminf/scheduler/__init__.py +++ b/bminf/scheduler/__init__.py @@ -1,12 +1,24 @@ import torch -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Set from cpm_kernels.library import cudart +from typing_extensions import TypedDict -def calc_fixed_layers(total_layers : int, max_fixed : int): +class ParameterInfo(TypedDict): + shape : torch.Size + dtype : torch.dtype + + +class SchedLayerInfo(TypedDict): + parameters : Dict[str, torch.Tensor] + evt : torch.cuda.Event + unused : bool + id : int + +def calc_fixed_layers(total_layers : int, max_fixed : int) -> List[int]: max_fixed = min(max_fixed, total_layers) scheduled_layers = total_layers - max_fixed vals = [(i + 1) * scheduled_layers // total_layers for i in range(total_layers)] - ret = [] + ret : List[int] = [] last_v = 0 for i, v in enumerate(vals): if v == last_v: @@ -19,16 +31,23 @@ def pin_layer(m : torch.nn.Module): for param in m.parameters(): with torch.no_grad(): param.data = param.data.pin_memory() + for buf in m.buffers(): + with torch.no_grad(): + buf.data = buf.data.pin_memory() return m -def transfer_layers(m_src : torch.nn.Module, m_dst : dict): +def transfer_layers(m_src : torch.nn.Module, m_dst : Dict[str, torch.Tensor]): with torch.no_grad(): for name, param in m_src.named_parameters(): assert name in m_dst # copy to device buffer m_dst[name].copy_(param, non_blocking=True) + for name, buf in m_src.named_buffers(): + assert name in m_dst + m_dst[name].copy_(buf, non_blocking=True) + -def swap_params(m_src : torch.nn.Module, m_dst : dict): +def swap_params(m_src : torch.nn.Module, m_dst : Dict[str, torch.Tensor]): with torch.no_grad(): for name, param in m_src.named_parameters(): assert name in m_dst @@ -37,6 +56,13 @@ def swap_params(m_src : torch.nn.Module, m_dst : dict): tmp = m_dst[name].data m_dst[name].data = param.data param.data = tmp + for name, buf in m_src.named_buffers(): + assert name in m_dst + + # swap memory info + tmp = m_dst[name].data + m_dst[name].data = buf.data + buf.data = tmp class OpDeviceLayer(torch.autograd.Function): @staticmethod @@ -176,8 +202,8 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit self._device = device_id self._num_layers = len(layers) - self._fixed_layers = set() - self._sched_layers = [] + self._fixed_layers : Set[int] = set() + self._sched_layers : List[SchedLayerInfo] = [] self._layers = [] self._active_layers = {} @@ -193,6 +219,8 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit total_size = 0 for param in layers[0].parameters(): total_size += param.numel() * param.storage().element_size() + for buf in layers[0].buffers(): + total_size += buf.numel() * buf.storage().element_size() total_layers = free_mem // total_size if total_layers < 2: @@ -217,7 +245,12 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit if i not in self._fixed_layers: self._active_layers[i] = len(self._sched_layers) self._sched_layers.append({ - 
"parameters": { name: param.cuda() for name, param in layers[i].named_parameters()}, + "parameters": { + name: param.cuda() for name, param in ( + list(layers[i].named_parameters()) + + list(layers[i].named_buffers()) + ) + }, "evt": torch.cuda.Event(), "id": i, "unused": True @@ -375,6 +408,9 @@ def __iter__(self): for sched in self._scheds: for layer in sched: yield layer + + def __len__(self): + return len(self.layers) def forward(self, x, *args, **kwargs): for sched in self._scheds: diff --git a/example/huggingface/gpt-j.ipynb b/example/huggingface/gpt-j.ipynb new file mode 100644 index 0000000..ad315b5 --- /dev/null +++ b/example/huggingface/gpt-j.ipynb @@ -0,0 +1,211 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GPT-J 6B" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load model and tokenizer from HuggingFace Hub\n", + "\n", + "GPT-J is loaded in fp32 mode by default which takes about 24GB CPU memory." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-j-6B\")\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/gpt-j-6B\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Use BMInf wrapper for low-resource inference" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import bminf\n", + "with torch.cuda.device(0):\n", + " model = bminf.wrapper(model, quantization=False, memory_limit=8 << 30) # 8GB" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
See the GPU usage" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "|===========================================================================|\n", + "| PyTorch CUDA memory summary, device ID 0 |\n", + "|---------------------------------------------------------------------------|\n", + "| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n", + "|===========================================================================|\n", + "| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n", + "|---------------------------------------------------------------------------|\n", + "| Allocated memory | 9297 MB | 9297 MB | 9297 MB | 0 B |\n", + "| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n", + "| from small pool | 1 MB | 1 MB | 1 MB | 0 B |\n", + "|---------------------------------------------------------------------------|\n", + "| Active memory | 9297 MB | 9297 MB | 9297 MB | 0 B |\n", + "| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n", + "| from small pool | 1 MB | 1 MB | 1 MB | 0 B |\n", + "|---------------------------------------------------------------------------|\n", + "| GPU reserved memory | 9298 MB | 9298 MB | 9298 MB | 0 B |\n", + "| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n", + "| from small pool | 2 MB | 2 MB | 2 MB | 0 B |\n", + "|---------------------------------------------------------------------------|\n", + "| Non-releasable memory | 710656 B | 18400 KB | 34800 KB | 34106 KB |\n", + "| from large pool | 0 B | 16384 KB | 32768 KB | 32768 KB |\n", + "| from small pool | 710656 B | 2032 KB | 2032 KB | 1338 KB |\n", + "|---------------------------------------------------------------------------|\n", + "| Allocations | 125 | 125 | 125 | 0 |\n", + "| from large pool | 72 | 72 | 72 | 0 |\n", + "| from small pool | 53 | 53 | 53 | 0 |\n", + "|---------------------------------------------------------------------------|\n", + "| Active allocs | 125 | 125 | 125 | 0 |\n", + "| from large pool | 72 | 72 | 72 | 0 |\n", + "| from small pool | 53 | 53 | 53 | 0 |\n", + "|---------------------------------------------------------------------------|\n", + "| GPU reserved segments | 65 | 65 | 65 | 0 |\n", + "| from large pool | 64 | 64 | 64 | 0 |\n", + "| from small pool | 1 | 1 | 1 | 0 |\n", + "|---------------------------------------------------------------------------|\n", + "| Non-releasable allocs | 1 | 2 | 3 | 2 |\n", + "| from large pool | 0 | 1 | 2 | 2 |\n", + "| from small pool | 1 | 1 | 1 | 0 |\n", + "|---------------------------------------------------------------------------|\n", + "| Oversize allocations | 0 | 0 | 0 | 0 |\n", + "|---------------------------------------------------------------------------|\n", + "| Oversize GPU segments | 0 | 0 | 0 | 0 |\n", + "|===========================================================================|\n", + "\n" + ] + } + ], + "source": [ + "print(torch.cuda.memory_summary())" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Run generation" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + } + ], + "source": [ + "prompt = \"To be or not to be, that\"\n", + "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", + "gen_tokens = model.generate(\n", + " input_ids.cuda(),\n", + " do_sample=True,\n", + " temperature=0.9,\n", + " max_length=20\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Get the generated text" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['To be or not to be, that is the question — that has been the question, and still']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.batch_decode(gen_tokens)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "29d71688ffbe7d005e79abd80e578fa5cab2d2c2e11d1955de002b95fcc7229b" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt index 7c9ca88..37f5cb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ torch -cpm_kernels>=1.0.9 \ No newline at end of file +cpm_kernels>=1.0.9 +typing_extensions
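
For reference, the notebook above condenses into the plain-Python sketch below. The model name, `memory_limit`, and sampling settings are taken verbatim from the notebook; peak GPU usage will vary with hardware and driver, so treat the 8GB cap as a starting point rather than a guarantee.

    import torch
    import bminf
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # GPT-J loads in fp32 by default and needs roughly 24GB of CPU memory.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

    # Wrap the model with BMInf on GPU 0, capping scheduled GPU memory at 8GB.
    with torch.cuda.device(0):
        model = bminf.wrapper(model, quantization=False, memory_limit=8 << 30)

    # Sample a short continuation and decode it.
    prompt = "To be or not to be, that"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_tokens = model.generate(
        input_ids.cuda(),
        do_sample=True,
        temperature=0.9,
        max_length=20,
    )
    print(tokenizer.batch_decode(gen_tokens)[0])

This example depends on the scheduler changes in this patch: `pin_layer`, `transfer_layers`, and `swap_params` now iterate over `named_buffers()` in addition to `named_parameters()`, so registered buffers (such as GPT-J's attention mask buffers) follow their layer between CPU and GPU instead of being left behind.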