
Commit

Merge branch 'master' into gma/add_autotp_workflow
loadams authored Feb 22, 2024
2 parents e5b0b18 + d5fa87f commit 934bbc9
Showing 30 changed files with 565 additions and 131 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/nv-inference.yml
@@ -32,7 +32,7 @@ jobs:

- name: Install pytorch
run: |
- pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
+ pip install -U --cache-dir $TORCH_CACHE torch==2.1.2 torchvision==0.16.2 --extra-index-url https://download.pytorch.org/whl/cu118
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
@@ -57,6 +57,6 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
- #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
- pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="1.13" --cuda_ver="11.6"
- pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
+ #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
+ pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="2.1" --cuda_ver="11.8"
+ pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
15 changes: 8 additions & 7 deletions accelerator/cuda_accelerator.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

+ import functools
import os
import pkgutil
import importlib
@@ -260,31 +261,31 @@ def replay_graph(self, graph):

@property
def BFloat16Tensor(self):
- return torch.cuda.BFloat16Tensor
+ return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda')

@property
def ByteTensor(self):
- return torch.cuda.ByteTensor
+ return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda')

@property
def DoubleTensor(self):
- return torch.cuda.DoubleTensor
+ return functools.partial(torch.tensor, dtype=torch.double, device='cuda')

@property
def FloatTensor(self):
- return torch.cuda.FloatTensor
+ return functools.partial(torch.tensor, dtype=torch.float, device='cuda')

@property
def HalfTensor(self):
- return torch.cuda.HalfTensor
+ return functools.partial(torch.tensor, dtype=torch.half, device='cuda')

@property
def IntTensor(self):
- return torch.cuda.IntTensor
+ return functools.partial(torch.tensor, dtype=torch.int, device='cuda')

@property
def LongTensor(self):
- return torch.cuda.LongTensor
+ return functools.partial(torch.tensor, dtype=torch.long, device='cuda')

def pin_memory(self, tensor, align_bytes=1):
return tensor.pin_memory()
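The new accelerator properties return callables rather than the legacy torch.cuda.*Tensor classes. A minimal sketch of how a caller would use one of them, assuming a CUDA device is available; the sample data is illustrative, not from this commit:

import torch
from deepspeed.accelerator import get_accelerator

accel = get_accelerator()

# The property now yields functools.partial(torch.tensor, dtype=torch.float, device='cuda'),
# so calling it with data still produces a float32 tensor on the CUDA device.
x = accel.FloatTensor([1.0, 2.0, 3.0])
print(x.dtype, x.device)  # torch.float32 cuda:0

Note that torch.tensor interprets its first argument as data, so size-style calls such as torch.cuda.FloatTensor(3) are not covered by this sketch.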
7 changes: 2 additions & 5 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -40,16 +40,13 @@ def _fetch_checkpoint_files(self):
# currently coming from the ckpt engine init but maybe a catch all kwargs for other
# snapshot download parameters would be more flexible.

- # NOTE(jeff): allow_patterns here are explicitly not using safetensors or other
- # checkpoint files that may be present. Example of all files in the llama-2-7b
- # repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main
- from huggingface_hub import snapshot_download, list_files_info
+ from huggingface_hub import snapshot_download, list_repo_tree

def model_has_safetensors(model_name_or_path: str) -> bool:
if os.path.isdir(model_name_or_path):
file_list = os.listdir(model_name_or_path)
else:
- file_list = [rf.rfilename for rf in list_files_info(model_name_or_path)]
+ file_list = [rf.path for rf in list_repo_tree(model_name_or_path)]
for f in file_list:
if f.endswith(".safetensors"):
return True
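The commit swaps list_files_info for list_repo_tree, whose entries expose .path rather than .rfilename. A small sketch of the same safetensors check against a remote repo, assuming a recent huggingface_hub (>= 0.20) and network access; the helper name and repo id below are illustrative:

from huggingface_hub import list_repo_tree

def repo_has_safetensors(model_name_or_path: str) -> bool:
    # Each entry is a RepoFile or RepoFolder; folder paths never end in ".safetensors".
    return any(entry.path.endswith(".safetensors") for entry in list_repo_tree(model_name_or_path))

# Example (downloads repo metadata only):
# print(repo_has_safetensors("meta-llama/Llama-2-7b-hf"))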
13 changes: 13 additions & 0 deletions deepspeed/profiling/flops_profiler/profiler.py
@@ -827,6 +827,15 @@ def _elementwise_flops_compute(input, other):
return flops, 0


+ def _attn_flops_compute(q, k, v, *args, **kwargs):
+ """
+ Count flops for the scaled_dot_product_attention operation.
+ """
+ macs = _prod(q.shape) * k.shape[-2]
+ macs += _prod(q.shape[:-1]) * k.shape[-2] * v.shape[-1]
+ return 2 * macs, macs


def wrapFunc(func, funcFlopCompute):
oldFunc = func
name = func.__str__
@@ -899,10 +908,14 @@ def _patch_functionals():
# embedding
F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)

+ # attn
+ F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)


def _patch_tensor_methods():
torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute)
torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute)
+ torch.Tensor.__matmul__ = wrapFunc(torch.Tensor.__matmul__, _matmul_flops_compute)
torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
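As a sanity check on the counting added in _attn_flops_compute: the first term counts the MACs of Q @ K^T and the second the MACs of the attention-weighted sum over V, with FLOPs reported as 2 * MACs. A small worked example with illustrative shapes (not taken from the commit):

from math import prod

B, H, Lq, Lk, D = 2, 8, 128, 128, 64   # batch, heads, query len, key len, head dim
q_shape, k_shape, v_shape = (B, H, Lq, D), (B, H, Lk, D), (B, H, Lk, D)

macs = prod(q_shape) * k_shape[-2]                      # Q @ K^T: B*H*Lq*D*Lk = 16_777_216
macs += prod(q_shape[:-1]) * k_shape[-2] * v_shape[-1]  # attn @ V: B*H*Lq*Lk*D = 16_777_216
flops = 2 * macs
print(flops, macs)  # 67108864 33554432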
