Skip to content

Commit

Permalink
python-style
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Dec 11, 2024
1 parent 82d60e6 commit 84aa0b3
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions vllm/compilation/backends.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -58,15 +58,18 @@ def serialize(self) -> str:
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)

- def exists(self, runtime_shape: Optional[int], graph_index) -> bool:
+ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
+ runtime_shape, graph_index = key
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]

- def get(self, runtime_shape: Optional[int], graph_index) -> str:
+ def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+ runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]

- def store(self, runtime_shape: Optional[int], graph_index, hash_str: str):
- self.cache[runtime_shape][graph_index] = hash_str
+ def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+ runtime_shape, graph_index = key
+ self.cache[runtime_shape][graph_index] = value


class AlwaysHitShapeEnv:
Expand Down Expand Up @@ -119,10 +122,10 @@ def wrap_inductor(graph,
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
cache_data = compilation_config.inductor_hash_cache
- if cache_data.exists(runtime_shape, graph_index):
+ if (runtime_shape, graph_index) in cache_data:
from torch._inductor.codecache import FxGraphCache
from torch._inductor.compile_fx import graph_returns_tuple
- hash_str = cache_data.get(runtime_shape, graph_index)
+ hash_str = cache_data[(runtime_shape, graph_index)]
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()):
inductor_compiled_graph = FxGraphCache._lookup_graph(
Expand All @@ -142,7 +145,7 @@ def compiled_graph(*args):
def mocked_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal cache_data
- cache_data.store(runtime_shape, graph_index, out[0])
+ cache_data[(runtime_shape, graph_index)] = out[0]
return out

with patch("torch._inductor.codecache.compiled_fx_graph_hash",
Expand Down

0 comments on commit 84aa0b3

Please sign in to comment.