Commit de421dc

feat: remove debug cuda avoid
drbh committed Feb 26, 2024
1 parent 66f8912 commit de421dc
Showing 1 changed file with 1 addition and 14 deletions.
server/text_generation_server/models/__init__.py (1 addition, 14 deletions)
@@ -8,6 +8,7 @@
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -45,20 +46,6 @@
 
 FLASH_ATTENTION = True
 
-# FlashCausalLM requires CUDA Graphs to be enabled on the system. This will throw a RuntimeError
-# if CUDA Graphs are not available when calling `torch.cuda.graph_pool_handle()` in the FlashCausalLM
-HAS_CUDA_GRAPH = False
-try:
-    from text_generation_server.models.flash_causal_lm import FlashCausalLM
-
-    HAS_CUDA_GRAPH = True
-except RuntimeError as e:
-    logger.warning(f"Could not import FlashCausalLM: {e}")
-
-if HAS_CUDA_GRAPH:
-    __all__.append(FlashCausalLM)
-
-
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
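
For context on the guard this commit removes: the deleted comment notes that `torch.cuda.graph_pool_handle()` raises a RuntimeError when CUDA Graphs are not available, which is why the import was wrapped in try/except. Below is a minimal standalone sketch of that probe; the helper name `cuda_graphs_available` is illustrative and not part of the codebase, and the loguru-style logger mirrors the `logger.warning` call in the deleted lines.

import torch
from loguru import logger


def cuda_graphs_available() -> bool:
    # Per the deleted comment, torch.cuda.graph_pool_handle() raises a
    # RuntimeError on systems where CUDA Graphs cannot be used.
    try:
        torch.cuda.graph_pool_handle()
        return True
    except RuntimeError as e:
        logger.warning(f"CUDA Graphs unavailable: {e}")
        return False

After this change, FlashCausalLM is imported unconditionally at the top of the module, so a missing CUDA Graphs capability would surface as an error at import time rather than being downgraded to a warning.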
