diff --git a/flake.lock b/flake.lock index ec87d569231..81b0d2a13c6 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1732218602, - "narHash": "sha256-BElslL34KjOJCFMPkNtilOz6S/7iY7Vd72FNbRRWKDY=", + "lastModified": 1734861790, + "narHash": "sha256-3afC0dDIkjOICziL4voDchZIkP14g8KM0xilGjt0cio=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "f79638ac4e420e661321261744e745a3a747e182", + "rev": "29728b3bb43517114aa3025a270bcda4fe78de9f", "type": "github" }, "original": { "owner": "huggingface", + "ref": "flashinfer-v0.2", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index 83cedfa620f..a302db3eac8 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/flashinfer-v0.2"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/server/Makefile-flashinfer b/server/Makefile-flashinfer index f0a27622a17..1712827b9b7 100644 --- a/server/Makefile-flashinfer +++ b/server/Makefile-flashinfer @@ -1,2 +1,2 @@ install-flashinfer: - pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 + pip install flashinfer==0.2.0 -i https://flashinfer.ai/whl/cu124/torch2.4 diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 26a72d9be71..ea1bc1d7f97 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -93,7 +93,7 @@ def use_prefill_with_paged_kv_state( head_dim=head_size, q_data_type=dtype, page_size=page_size, - window_left=window_left, + window_left=-1 if window_left is None else window_left, ) yield finally: @@ -139,7 +139,7 @@ def use_prefill_state( num_kv_heads=num_kv_heads, head_dim=head_size, q_data_type=dtype, - window_left=window_left, + window_left=-1 if window_left is None else window_left, ) yield finally: @@ -243,7 +243,7 @@ def use_decode_state( page_size=page_size, data_type=kv_cache_dtype, q_data_type=dtype, - window_left=window_left, + window_left=-1 if window_left is None else window_left, ) yield finally: