diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json index 49bc996c195..a7f7d2f0bdb 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json @@ -16,7 +16,7 @@ }, { "id": 2009, - "logprob": -11.5546875, + "logprob": -11.546875, "text": "request" } ], @@ -24,65 +24,66 @@ "tokens": [ { "id": 363, - "logprob": -1.5380859, + "logprob": -1.5351562, "special": false, "text": " for" }, { "id": 847, - "logprob": -2.5917969, + "logprob": -2.5722656, "special": false, "text": " /" }, { "id": 2754, - "logprob": -2.2773438, + "logprob": -2.2714844, "special": false, "text": "api" }, { "id": 29914, - "logprob": -0.034362793, + "logprob": -0.03414917, "special": false, "text": "/" }, { "id": 29894, - "logprob": -0.96533203, + "logprob": -0.95996094, "special": false, "text": "v" }, { "id": 29896, - "logprob": -0.36669922, + "logprob": -0.3635254, "special": false, "text": "1" }, { "id": 29914, - "logprob": -0.013122559, + "logprob": -0.013031006, "special": false, "text": "/" }, { "id": 16418, - "logprob": -3.1503906, + "logprob": -3.1523438, "special": false, "text": "projects" }, { "id": 29914, - "logprob": -0.43652344, + "logprob": -0.43701172, "special": false, "text": "/" }, { "id": 29896, - "logprob": -1.9404297, + "logprob": -1.9394531, "special": false, "text": "1" } - ] + ], + "top_tokens": null }, - "generated_text": "for /api/v1/projects/1" + "generated_text": " for /api/v1/projects/1" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json index 5be2870da8f..9f145377725 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json @@ -16,7 +16,7 @@ }, { "id": 2009, - "logprob": -11.5546875, + "logprob": -11.546875, "text": "request" } ], @@ -24,19 +24,19 @@ "tokens": [ { "id": 5229, - "logprob": -2.5683594, + "logprob": -2.5839844, "special": false, "text": " failed" }, { "id": 29901, - "logprob": -0.45336914, + "logprob": -0.44970703, "special": false, "text": ":" }, { "id": 4829, - "logprob": -1.8408203, + "logprob": -1.8339844, "special": false, "text": " Error" }, @@ -52,7 +52,8 @@ "special": false, "text": " test" } - ] + ], + "top_tokens": null }, - "generated_text": "Test requestfailed: Error in test" + "generated_text": "Test request failed: Error in test" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json index 9bbb5322576..3543dad2353 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json @@ -17,7 +17,7 @@ }, { "id": 2009, - "logprob": -11.5546875, + "logprob": -11.546875, "text": "request" } ], @@ -25,25 +25,25 @@ "tokens": [ { "id": 363, - "logprob": -1.5380859, + "logprob": -1.5351562, "special": false, "text": " for" }, { "id": 847, - "logprob": -2.5859375, + "logprob": -2.5566406, "special": false, "text": " /" }, { "id": 2754, - "logprob": -2.2695312, + "logprob": -2.2519531, "special": false, "text": "api" }, { "id": 29914, - "logprob": -0.03439331, + "logprob": -0.03414917, "special": false, "text": "/" }, @@ -55,13 +55,13 @@ }, { "id": 29896, - "logprob": -0.36694336, + "logprob": -0.3647461, "special": false, "text": "1" }, { "id": 29914, - "logprob": -0.013114929, + "logprob": -0.012901306, "special": false, "text": "/" }, @@ -73,19 +73,20 @@ }, { "id": 29914, - "logprob": -0.43847656, + "logprob": -0.4362793, "special": false, "text": "/" }, { "id": 29896, - "logprob": -1.9433594, + "logprob": -1.9394531, "special": false, "text": "1" } - ] + ], + "top_tokens": null }, - "generated_text": "for /api/v1/projects/1" + "generated_text": " for /api/v1/projects/1" }, { "details": { @@ -105,7 +106,7 @@ }, { "id": 2009, - "logprob": -11.5546875, + "logprob": -11.546875, "text": "request" } ], @@ -113,43 +114,43 @@ "tokens": [ { "id": 363, - "logprob": -1.5322266, + "logprob": -1.5332031, "special": false, "text": " for" }, { "id": 847, - "logprob": -2.5585938, + "logprob": -2.5625, "special": false, "text": " /" }, { "id": 2754, - "logprob": -2.265625, + "logprob": -2.2617188, "special": false, "text": "api" }, { "id": 29914, - "logprob": -0.034088135, + "logprob": -0.033996582, "special": false, "text": "/" }, { "id": 29894, - "logprob": -0.96240234, + "logprob": -0.9609375, "special": false, "text": "v" }, { "id": 29896, - "logprob": -0.36816406, + "logprob": -0.36572266, "special": false, "text": "1" }, { "id": 29914, - "logprob": -0.013191223, + "logprob": -0.0129776, "special": false, "text": "/" }, @@ -161,19 +162,20 @@ }, { "id": 29914, - "logprob": -0.43774414, + "logprob": -0.4362793, "special": false, "text": "/" }, { "id": 29896, - "logprob": -1.9443359, + "logprob": -1.9394531, "special": false, "text": "1" } - ] + ], + "top_tokens": null }, - "generated_text": "for /api/v1/projects/1" + "generated_text": " for /api/v1/projects/1" }, { "details": { @@ -193,7 +195,7 @@ }, { "id": 2009, - "logprob": -11.5546875, + "logprob": -11.546875, "text": "request" } ], @@ -201,43 +203,43 @@ "tokens": [ { "id": 363, - "logprob": -1.5322266, + "logprob": -1.5332031, "special": false, "text": " for" }, { "id": 847, - "logprob": -2.5585938, + "logprob": -2.5625, "special": false, "text": " /" }, { "id": 2754, - "logprob": -2.265625, + "logprob": -2.2617188, "special": false, "text": "api" }, { "id": 29914, - "logprob": -0.034088135, + "logprob": -0.033996582, "special": false, "text": "/" }, { "id": 29894, - "logprob": -0.96240234, + "logprob": -0.9609375, "special": false, "text": "v" }, { "id": 29896, - "logprob": -0.36816406, + "logprob": -0.36572266, "special": false, "text": "1" }, { "id": 29914, - "logprob": -0.013191223, + "logprob": -0.0129776, "special": false, "text": "/" }, @@ -249,19 +251,20 @@ }, { "id": 29914, - "logprob": -0.43774414, + "logprob": -0.4362793, "special": false, "text": "/" }, { "id": 29896, - "logprob": -1.9443359, + "logprob": -1.9394531, "special": false, "text": "1" } - ] + ], + "top_tokens": null }, - "generated_text": "for /api/v1/projects/1" + "generated_text": " for /api/v1/projects/1" }, { "details": { @@ -281,7 +284,7 @@ }, { "id": 2009, - "logprob": -11.5546875, + "logprob": -11.546875, "text": "request" } ], @@ -289,43 +292,43 @@ "tokens": [ { "id": 363, - "logprob": -1.5322266, + "logprob": -1.5332031, "special": false, "text": " for" }, { "id": 847, - "logprob": -2.5585938, + "logprob": -2.5625, "special": false, "text": " /" }, { "id": 2754, - "logprob": -2.265625, + "logprob": -2.2617188, "special": false, "text": "api" }, { "id": 29914, - "logprob": -0.034088135, + "logprob": -0.033996582, "special": false, "text": "/" }, { "id": 29894, - "logprob": -0.96240234, + "logprob": -0.9609375, "special": false, "text": "v" }, { "id": 29896, - "logprob": -0.36816406, + "logprob": -0.36572266, "special": false, "text": "1" }, { "id": 29914, - "logprob": -0.013191223, + "logprob": -0.0129776, "special": false, "text": "/" }, @@ -337,18 +340,19 @@ }, { "id": 29914, - "logprob": -0.43774414, + "logprob": -0.4362793, "special": false, "text": "/" }, { "id": 29896, - "logprob": -1.9443359, + "logprob": -1.9394531, "special": false, "text": "1" } - ] + ], + "top_tokens": null }, - "generated_text": "for /api/v1/projects/1" + "generated_text": " for /api/v1/projects/1" } ] diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics.json index 0edd81b6bda..2c5d05f6036 100644 --- a/integration-tests/models/__snapshots__/test_idefics/test_idefics.json +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics.json @@ -11,22 +11,22 @@ }, { "id": 4911, - "logprob": -5.7773438, + "logprob": -5.7851562, "text": "User" }, { "id": 29901, - "logprob": -0.0069999695, + "logprob": -0.006996155, "text": ":" }, { "id": 32000, - "logprob": -0.8125, + "logprob": -0.81347656, "text": "" }, { "id": 32001, - "logprob": -6.651878e-05, + "logprob": -6.687641e-05, "text": "" }, { @@ -36,67 +36,67 @@ }, { "id": 1815, - "logprob": -4.2265625, + "logprob": -4.2148438, "text": "Can" }, { "id": 366, - "logprob": -0.013977051, + "logprob": -0.014137268, "text": "you" }, { "id": 2649, - "logprob": -4.4375, + "logprob": -4.4335938, "text": "tell" }, { "id": 592, - "logprob": -0.29077148, + "logprob": -0.2919922, "text": "me" }, { "id": 263, - "logprob": -4.2109375, + "logprob": -4.2070312, "text": "a" }, { "id": 1407, - "logprob": -9.4296875, + "logprob": -9.421875, "text": "very" }, { "id": 3273, - "logprob": -1.8671875, + "logprob": -1.8720703, "text": "short" }, { "id": 5828, - "logprob": -0.26586914, + "logprob": -0.26489258, "text": "story" }, { "id": 2729, - "logprob": -3.7460938, + "logprob": -3.7441406, "text": "based" }, { "id": 373, - "logprob": -0.0005350113, + "logprob": -0.0005393028, "text": "on" }, { "id": 278, - "logprob": -0.13867188, + "logprob": -0.140625, "text": "the" }, { "id": 1967, - "logprob": -0.06842041, + "logprob": -0.06756592, "text": "image" }, { "id": 29973, - "logprob": -0.15319824, + "logprob": -0.15454102, "text": "?" } ], @@ -104,7 +104,7 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019445419, + "logprob": -0.0019140244, "special": true, "text": "" }, @@ -116,13 +116,13 @@ }, { "id": 13, - "logprob": -1.7881393e-05, + "logprob": -1.7642975e-05, "special": false, "text": "\n" }, { "id": 7900, - "logprob": -3.0994415e-06, + "logprob": -2.9802322e-06, "special": false, "text": "Ass" }, @@ -140,30 +140,30 @@ }, { "id": 319, - "logprob": -0.9057617, + "logprob": -0.91064453, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2314453, + "logprob": -1.2412109, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.00024914742, + "logprob": -0.0002439022, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1621094, + "logprob": -1.1630859, "special": false, "text": " stands" } ], "top_tokens": null }, - "generated_text": "\nAssistant: A rooster stands" + "generated_text": " \nAssistant: A rooster stands" } diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json index 81cc1b19841..f258e38da41 100644 --- a/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json @@ -12,22 +12,22 @@ }, { "id": 4911, - "logprob": -5.7773438, + "logprob": -5.7851562, "text": "User" }, { "id": 29901, - "logprob": -0.0069999695, + "logprob": -0.006996155, "text": ":" }, { "id": 32000, - "logprob": -0.8125, + "logprob": -0.81347656, "text": "" }, { "id": 32001, - "logprob": -6.651878e-05, + "logprob": -6.687641e-05, "text": "" }, { @@ -37,67 +37,67 @@ }, { "id": 1815, - "logprob": -4.2265625, + "logprob": -4.2148438, "text": "Can" }, { "id": 366, - "logprob": -0.013977051, + "logprob": -0.014137268, "text": "you" }, { "id": 2649, - "logprob": -4.4375, + "logprob": -4.4335938, "text": "tell" }, { "id": 592, - "logprob": -0.29077148, + "logprob": -0.2919922, "text": "me" }, { "id": 263, - "logprob": -4.2109375, + "logprob": -4.2070312, "text": "a" }, { "id": 1407, - "logprob": -9.4296875, + "logprob": -9.421875, "text": "very" }, { "id": 3273, - "logprob": -1.8671875, + "logprob": -1.8720703, "text": "short" }, { "id": 5828, - "logprob": -0.26586914, + "logprob": -0.26489258, "text": "story" }, { "id": 2729, - "logprob": -3.7460938, + "logprob": -3.7441406, "text": "based" }, { "id": 373, - "logprob": -0.0005350113, + "logprob": -0.0005393028, "text": "on" }, { "id": 278, - "logprob": -0.13867188, + "logprob": -0.140625, "text": "the" }, { "id": 1967, - "logprob": -0.06842041, + "logprob": -0.06756592, "text": "image" }, { "id": 29973, - "logprob": -0.15319824, + "logprob": -0.15454102, "text": "?" } ], @@ -105,13 +105,13 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019445419, + "logprob": -0.0019140244, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.416176e-05, + "logprob": -8.392334e-05, "special": false, "text": " " }, @@ -123,7 +123,7 @@ }, { "id": 7900, - "logprob": -3.0994415e-06, + "logprob": -2.9802322e-06, "special": false, "text": "Ass" }, @@ -135,38 +135,38 @@ }, { "id": 29901, - "logprob": -3.2186508e-06, + "logprob": -3.0994415e-06, "special": false, "text": ":" }, { "id": 319, - "logprob": -0.89941406, + "logprob": -0.9057617, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.234375, + "logprob": -1.2294922, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.0002465248, + "logprob": -0.00024533272, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1660156, + "logprob": -1.1640625, "special": false, "text": " stands" } ], "top_tokens": null }, - "generated_text": "\nAssistant: A rooster stands" + "generated_text": " \nAssistant: A rooster stands" }, { "details": { @@ -181,22 +181,22 @@ }, { "id": 4911, - "logprob": -5.7890625, + "logprob": -5.7773438, "text": "User" }, { "id": 29901, - "logprob": -0.0070152283, + "logprob": -0.0070114136, "text": ":" }, { "id": 32000, - "logprob": -0.8125, + "logprob": -0.8208008, "text": "" }, { "id": 32001, - "logprob": -6.651878e-05, + "logprob": -6.699562e-05, "text": "" }, { @@ -211,17 +211,17 @@ }, { "id": 366, - "logprob": -0.014190674, + "logprob": -0.014175415, "text": "you" }, { "id": 2649, - "logprob": -4.4140625, + "logprob": -4.4296875, "text": "tell" }, { "id": 592, - "logprob": -0.2919922, + "logprob": -0.29516602, "text": "me" }, { @@ -231,7 +231,7 @@ }, { "id": 1407, - "logprob": -9.4375, + "logprob": -9.4296875, "text": "very" }, { @@ -241,7 +241,7 @@ }, { "id": 5828, - "logprob": -0.26904297, + "logprob": -0.26879883, "text": "story" }, { @@ -251,22 +251,22 @@ }, { "id": 373, - "logprob": -0.0005402565, + "logprob": -0.0005354881, "text": "on" }, { "id": 278, - "logprob": -0.13867188, + "logprob": -0.13671875, "text": "the" }, { "id": 1967, - "logprob": -0.068359375, + "logprob": -0.06719971, "text": "image" }, { "id": 29973, - "logprob": -0.15539551, + "logprob": -0.15551758, "text": "?" } ], @@ -274,7 +274,7 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019168854, + "logprob": -0.0019130707, "special": true, "text": "" }, @@ -286,7 +286,7 @@ }, { "id": 13, - "logprob": -1.7642975e-05, + "logprob": -1.7881393e-05, "special": false, "text": "\n" }, @@ -310,32 +310,32 @@ }, { "id": 319, - "logprob": -0.90722656, + "logprob": -0.9013672, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2373047, + "logprob": -1.2324219, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.00024938583, + "logprob": -0.0002477169, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1708984, + "logprob": -1.1660156, "special": false, "text": " stands" } ], "top_tokens": null }, - "generated_text": "\nAssistant: A rooster stands" + "generated_text": " \nAssistant: A rooster stands" }, { "details": { @@ -350,22 +350,22 @@ }, { "id": 4911, - "logprob": -5.7890625, + "logprob": -5.7773438, "text": "User" }, { "id": 29901, - "logprob": -0.0070152283, + "logprob": -0.0070114136, "text": ":" }, { "id": 32000, - "logprob": -0.8125, + "logprob": -0.8208008, "text": "" }, { "id": 32001, - "logprob": -6.663799e-05, + "logprob": -6.699562e-05, "text": "" }, { @@ -380,17 +380,17 @@ }, { "id": 366, - "logprob": -0.014190674, + "logprob": -0.014175415, "text": "you" }, { "id": 2649, - "logprob": -4.4140625, + "logprob": -4.4296875, "text": "tell" }, { "id": 592, - "logprob": -0.2919922, + "logprob": -0.29516602, "text": "me" }, { @@ -400,7 +400,7 @@ }, { "id": 1407, - "logprob": -9.4375, + "logprob": -9.4296875, "text": "very" }, { @@ -410,7 +410,7 @@ }, { "id": 5828, - "logprob": -0.26904297, + "logprob": -0.26879883, "text": "story" }, { @@ -420,22 +420,22 @@ }, { "id": 373, - "logprob": -0.0005402565, + "logprob": -0.0005354881, "text": "on" }, { "id": 278, - "logprob": -0.13867188, + "logprob": -0.13671875, "text": "the" }, { "id": 1967, - "logprob": -0.068359375, + "logprob": -0.06719971, "text": "image" }, { "id": 29973, - "logprob": -0.15539551, + "logprob": -0.15551758, "text": "?" } ], @@ -443,19 +443,19 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019168854, + "logprob": -0.001912117, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.404255e-05, + "logprob": -8.392334e-05, "special": false, "text": " " }, { "id": 13, - "logprob": -1.7642975e-05, + "logprob": -1.7762184e-05, "special": false, "text": "\n" }, @@ -479,32 +479,32 @@ }, { "id": 319, - "logprob": -0.90722656, + "logprob": -0.9013672, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2373047, + "logprob": -1.2324219, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.00024938583, + "logprob": -0.0002477169, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1708984, + "logprob": -1.1660156, "special": false, "text": " stands" } ], "top_tokens": null }, - "generated_text": "\nAssistant: A rooster stands" + "generated_text": " \nAssistant: A rooster stands" }, { "details": { @@ -519,22 +519,22 @@ }, { "id": 4911, - "logprob": -5.7890625, + "logprob": -5.7773438, "text": "User" }, { "id": 29901, - "logprob": -0.0070152283, + "logprob": -0.0070114136, "text": ":" }, { "id": 32000, - "logprob": -0.8125, + "logprob": -0.8208008, "text": "" }, { "id": 32001, - "logprob": -6.663799e-05, + "logprob": -6.699562e-05, "text": "" }, { @@ -549,17 +549,17 @@ }, { "id": 366, - "logprob": -0.014190674, + "logprob": -0.014175415, "text": "you" }, { "id": 2649, - "logprob": -4.4140625, + "logprob": -4.4296875, "text": "tell" }, { "id": 592, - "logprob": -0.2919922, + "logprob": -0.29516602, "text": "me" }, { @@ -569,7 +569,7 @@ }, { "id": 1407, - "logprob": -9.4375, + "logprob": -9.4296875, "text": "very" }, { @@ -579,7 +579,7 @@ }, { "id": 5828, - "logprob": -0.26904297, + "logprob": -0.26879883, "text": "story" }, { @@ -589,22 +589,22 @@ }, { "id": 373, - "logprob": -0.0005402565, + "logprob": -0.0005354881, "text": "on" }, { "id": 278, - "logprob": -0.13867188, + "logprob": -0.13671875, "text": "the" }, { "id": 1967, - "logprob": -0.068359375, + "logprob": -0.06719971, "text": "image" }, { "id": 29973, - "logprob": -0.15539551, + "logprob": -0.15551758, "text": "?" } ], @@ -612,19 +612,19 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019159317, + "logprob": -0.001912117, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.404255e-05, + "logprob": -8.392334e-05, "special": false, "text": " " }, { "id": 13, - "logprob": -1.7642975e-05, + "logprob": -1.7762184e-05, "special": false, "text": "\n" }, @@ -648,31 +648,31 @@ }, { "id": 319, - "logprob": -0.90722656, + "logprob": -0.9013672, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2373047, + "logprob": -1.2324219, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.00024938583, + "logprob": -0.0002477169, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1708984, + "logprob": -1.1660156, "special": false, "text": " stands" } ], "top_tokens": null }, - "generated_text": "\nAssistant: A rooster stands" + "generated_text": " \nAssistant: A rooster stands" } ] diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 696f0fb23f9..35d74b2e9ff 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -641,8 +641,11 @@ def generate_token( if i % self.world_size == self.rank: if stop: # Decode generated tokens - output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :, 0] + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True ) # Get seed if isinstance(next_token_chooser.choice, Sampling): diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d6af07f4c01..12d8efebc91 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -793,11 +793,6 @@ def warmup(self, batch: FlashCausalLMBatch): return int(num_blocks * BLOCK_SIZE) - def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - def forward( self, input_ids: torch.Tensor, @@ -1008,8 +1003,11 @@ def generate_token( if i % self.world_size == self.rank: if stop: # Decode generated tokens - output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :] + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True ) generated_text = GeneratedText( output_text, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index f4177145b18..30cc2299cd3 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -611,11 +611,6 @@ def __init__( def batch_type(self) -> Type[IdeficsCausalLMBatch]: return IdeficsCausalLMBatch - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - def forward( self, input_ids, @@ -728,8 +723,11 @@ def generate_token( if i % self.world_size == self.rank: if stop: # Decode generated tokens - output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :, 0] + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True ) # Get seed if isinstance(next_token_chooser.choice, Sampling): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 806e98332b0..73329b24e10 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -64,16 +64,17 @@ def decode_token( all_input_ids: List[int], prefix_offset: int = 0, read_offset: int = 0, + skip_special_tokens: bool = False, ) -> Tuple[str, int, int]: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" # The prefix text is necessary only to defeat cleanup algorithms in the decode # which decide to add a space or not depending on the surrounding ids. prefix_text = self.tokenizer.decode( - all_input_ids[prefix_offset:read_offset], skip_special_tokens=False + all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens ) new_text = self.tokenizer.decode( - all_input_ids[prefix_offset:], skip_special_tokens=False + all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens ) if len(new_text) > len(prefix_text) and not new_text.endswith("�"): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 34932c0b504..f67874bed60 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -710,8 +710,11 @@ def generate_token( if stop: # Slice with decoder_input_length to remove padding # Decode all tokens - output_text = self.decode( - all_decoder_input_ids[-decoder_input_length:] + output_text, _, _ = self.decode_token( + all_decoder_input_ids, + prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1, + read_offset=len(all_decoder_input_ids) - decoder_input_length, + skip_special_tokens=True ) # Get seed