FIX #62; ADD GPT-J example #63
a710128 committed Jan 24, 2023
1 parent 949ffef commit cefa4cb
Showing 3 changed files with 257 additions and 9 deletions.
52 changes: 44 additions & 8 deletions bminf/scheduler/__init__.py
@@ -1,12 +1,24 @@
 import torch
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Set
 from cpm_kernels.library import cudart
+from typing_extensions import TypedDict
 
-def calc_fixed_layers(total_layers : int, max_fixed : int):
+class ParameterInfo(TypedDict):
+    shape : torch.Size
+    dtype : torch.dtype
+
+
+class SchedLayerInfo(TypedDict):
+    parameters : Dict[str, torch.Tensor]
+    evt : torch.cuda.Event
+    unused : bool
+    id : int
+
+def calc_fixed_layers(total_layers : int, max_fixed : int) -> List[int]:
     max_fixed = min(max_fixed, total_layers)
     scheduled_layers = total_layers - max_fixed
     vals = [(i + 1) * scheduled_layers // total_layers for i in range(total_layers)]
-    ret = []
+    ret : List[int] = []
     last_v = 0
     for i, v in enumerate(vals):
         if v == last_v:
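
For readers skimming the hunk above: calc_fixed_layers picks which layer indices stay permanently resident on the GPU, spreading them evenly through the stack. A minimal self-contained sketch, assuming the truncated tail of the loop continues as in the repository (the append/else branch is not shown above):

from typing import List

def calc_fixed_layers(total_layers: int, max_fixed: int) -> List[int]:
    # Indices where the integer quotient does not advance are kept
    # resident ("fixed"); the rest are swapped in and out on demand.
    max_fixed = min(max_fixed, total_layers)
    scheduled_layers = total_layers - max_fixed
    vals = [(i + 1) * scheduled_layers // total_layers for i in range(total_layers)]
    ret: List[int] = []
    last_v = 0
    for i, v in enumerate(vals):
        if v == last_v:
            ret.append(i)
        else:
            last_v = v
    return ret

print(calc_fixed_layers(6, 2))  # [0, 3] -- two fixed layers, evenly spaced
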
@@ -19,16 +31,23 @@ def pin_layer(m : torch.nn.Module):
     for param in m.parameters():
         with torch.no_grad():
             param.data = param.data.pin_memory()
+    for buf in m.buffers():
+        with torch.no_grad():
+            buf.data = buf.data.pin_memory()
     return m
 
-def transfer_layers(m_src : torch.nn.Module, m_dst : dict):
+def transfer_layers(m_src : torch.nn.Module, m_dst : Dict[str, torch.Tensor]):
     with torch.no_grad():
         for name, param in m_src.named_parameters():
             assert name in m_dst
             # copy to device buffer
             m_dst[name].copy_(param, non_blocking=True)
+        for name, buf in m_src.named_buffers():
+            assert name in m_dst
+            m_dst[name].copy_(buf, non_blocking=True)
 
-def swap_params(m_src : torch.nn.Module, m_dst : dict):
+
+def swap_params(m_src : torch.nn.Module, m_dst : Dict[str, torch.Tensor]):
     with torch.no_grad():
         for name, param in m_src.named_parameters():
             assert name in m_dst
@@ -37,6 +56,13 @@ def swap_params(m_src : torch.nn.Module, m_dst : dict):
             tmp = m_dst[name].data
             m_dst[name].data = param.data
             param.data = tmp
+        for name, buf in m_src.named_buffers():
+            assert name in m_dst
+
+            # swap memory info
+            tmp = m_dst[name].data
+            m_dst[name].data = buf.data
+            buf.data = tmp
 
 class OpDeviceLayer(torch.autograd.Function):
     @staticmethod
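
Context for the pinning above: page-locked host memory is what makes the scheduler's non_blocking copies truly asynchronous, and a CUDA event (the evt field in SchedLayerInfo) marks when a transfer has landed. A minimal sketch of that pattern, independent of BMInf:

import torch

# The host tensor must be pinned for copy_(..., non_blocking=True) to be
# an async DMA transfer rather than a blocking staged copy.
host = torch.randn(1024, 1024).pin_memory()
dev = torch.empty(1024, 1024, device="cuda")

copy_stream = torch.cuda.Stream()
evt = torch.cuda.Event()
with torch.cuda.stream(copy_stream):
    dev.copy_(host, non_blocking=True)  # enqueued, returns immediately
    evt.record()                        # marks completion on copy_stream

# ... other GPU work can proceed on the default stream here ...
torch.cuda.current_stream().wait_event(evt)  # sync only when `dev` is needed
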
@@ -176,8 +202,8 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit
         self._device = device_id
         self._num_layers = len(layers)
 
-        self._fixed_layers = set()
-        self._sched_layers = []
+        self._fixed_layers : Set[int] = set()
+        self._sched_layers : List[SchedLayerInfo] = []
         self._layers = []
         self._active_layers = {}
 
@@ -193,6 +219,8 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit
         total_size = 0
         for param in layers[0].parameters():
             total_size += param.numel() * param.storage().element_size()
+        for buf in layers[0].buffers():
+            total_size += buf.numel() * buf.storage().element_size()
 
         total_layers = free_mem // total_size
         if total_layers < 2:
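
The sizing fix above matters because a layer's device footprint is parameters plus buffers; counting only parameters would overestimate how many layers fit in the budget. A small worked example with a hypothetical toy layer:

import torch

layer = torch.nn.Module()
layer.weight = torch.nn.Parameter(torch.randn(4096, 4096))  # 64 MiB in fp32
layer.register_buffer("mask", torch.ones(2048, 2048, dtype=torch.bool))  # 4 MiB

total_size = 0
for param in layer.parameters():
    total_size += param.numel() * param.storage().element_size()
for buf in layer.buffers():
    total_size += buf.numel() * buf.storage().element_size()

free_mem = 8 << 30  # assume an 8 GiB budget, as in the notebook below
print(free_mem // total_size)  # how many such layers fit on the device
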
@@ -217,7 +245,12 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit
             if i not in self._fixed_layers:
                 self._active_layers[i] = len(self._sched_layers)
                 self._sched_layers.append({
-                    "parameters": { name: param.cuda() for name, param in layers[i].named_parameters()},
+                    "parameters": {
+                        name: param.cuda() for name, param in (
+                            list(layers[i].named_parameters())
+                            + list(layers[i].named_buffers())
+                        )
+                    },
                     "evt": torch.cuda.Event(),
                     "id": i,
                     "unused": True
@@ -375,6 +408,9 @@ def __iter__(self):
         for sched in self._scheds:
             for layer in sched:
                 yield layer
 
+    def __len__(self):
+        return len(self.layers)
+
     def forward(self, x, *args, **kwargs):
         for sched in self._scheds:
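
The common thread in this file: every place the scheduler enumerated named_parameters() now also walks named_buffers(), because PyTorch keeps non-trainable state (attention masks, position tables, and the like) out of the parameter list. A toy module, with hypothetical names, showing what the old code missed:

import torch

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        # register_buffer state moves with the module but is invisible to
        # named_parameters() -- exactly what fix #62 accounts for.
        self.register_buffer("mask", torch.tril(torch.ones(4, 4)))

m = ToyBlock()
print([name for name, _ in m.named_parameters()])  # ['weight']
print([name for name, _ in m.named_buffers()])     # ['mask']
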
211 changes: 211 additions & 0 deletions example/huggingface/gpt-j.ipynb
@@ -0,0 +1,211 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# GPT-J 6B"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load model and tokenizer from HuggingFace Hub\n",
"\n",
"GPT-J is loaded in fp32 mode by default which takes about 24GB CPU memory."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/gpt-j-6B\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Use BMInf wrapper for low-resource inference"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import bminf\n",
"with torch.cuda.device(0):\n",
" model = bminf.wrapper(model, quantization=False, memory_limit=8 << 30) # 8GB"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. See the GPU usage"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|===========================================================================|\n",
"| PyTorch CUDA memory summary, device ID 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
"|===========================================================================|\n",
"| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocated memory | 9297 MB | 9297 MB | 9297 MB | 0 B |\n",
"| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n",
"| from small pool | 1 MB | 1 MB | 1 MB | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| Active memory | 9297 MB | 9297 MB | 9297 MB | 0 B |\n",
"| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n",
"| from small pool | 1 MB | 1 MB | 1 MB | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved memory | 9298 MB | 9298 MB | 9298 MB | 0 B |\n",
"| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n",
"| from small pool | 2 MB | 2 MB | 2 MB | 0 B |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable memory | 710656 B | 18400 KB | 34800 KB | 34106 KB |\n",
"| from large pool | 0 B | 16384 KB | 32768 KB | 32768 KB |\n",
"| from small pool | 710656 B | 2032 KB | 2032 KB | 1338 KB |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocations | 125 | 125 | 125 | 0 |\n",
"| from large pool | 72 | 72 | 72 | 0 |\n",
"| from small pool | 53 | 53 | 53 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Active allocs | 125 | 125 | 125 | 0 |\n",
"| from large pool | 72 | 72 | 72 | 0 |\n",
"| from small pool | 53 | 53 | 53 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved segments | 65 | 65 | 65 | 0 |\n",
"| from large pool | 64 | 64 | 64 | 0 |\n",
"| from small pool | 1 | 1 | 1 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable allocs | 1 | 2 | 3 | 2 |\n",
"| from large pool | 0 | 1 | 2 | 2 |\n",
"| from small pool | 1 | 1 | 1 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize allocations | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize GPU segments | 0 | 0 | 0 | 0 |\n",
"|===========================================================================|\n",
"\n"
]
}
],
"source": [
"print(torch.cuda.memory_summary())"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Run generation"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
}
],
"source": [
"prompt = \"To be or not to be, that\"\n",
"input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
"gen_tokens = model.generate(\n",
" input_ids.cuda(),\n",
" do_sample=True,\n",
" temperature=0.9,\n",
" max_length=20\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Get the generated text"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['To be or not to be, that is the question — that has been the question, and still']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.batch_decode(gen_tokens)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "29d71688ffbe7d005e79abd80e578fa5cab2d2c2e11d1955de002b95fcc7229b"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
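
The stderr output in step 4 of the notebook is HuggingFace warning that no attention mask or pad token id was supplied. A variant of that generation cell which passes both explicitly (not part of the commit) would avoid it:

inputs = tokenizer("To be or not to be, that", return_tensors="pt")
gen_tokens = model.generate(
    inputs.input_ids.cuda(),
    attention_mask=inputs.attention_mask.cuda(),
    pad_token_id=tokenizer.eos_token_id,  # silences the open-end warning
    do_sample=True,
    temperature=0.9,
    max_length=20,
)
print(tokenizer.batch_decode(gen_tokens)[0])
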
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,2 +1,3 @@
 torch
-cpm_kernels>=1.0.9
+cpm_kernels>=1.0.9
+typing_extensions
