This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Migrate PR: [LLM Runtime] Add Whisper Example and Python API
Signed-off-by: intellinjun <[email protected]>
intellinjun committed Jan 10, 2024
1 parent a0a806e commit 7595348
Showing 32 changed files with 1,509 additions and 462 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -155,6 +155,18 @@ Neural Speed supports the following models:
<td> </td>
<td>Latest</td>
</tr>
<tr>
<td><a href="https://huggingface.co/openai/whisper-tiny" target="_blank" rel="noopener noreferrer">Whisper-tiny</a>,
<a href="https://huggingface.co/openai/whisper-base" target="_blank" rel="noopener noreferrer">Whisper-base</a>,
<a href="https://huggingface.co/openai/whisper-small" target="_blank" rel="noopener noreferrer">Whisper-small</a>,
<a href="https://huggingface.co/openai/whisper-medium" target="_blank" rel="noopener noreferrer">Whisper-medium</a>,
<a href="https://huggingface.co/openai/whisper-large" target="_blank" rel="noopener noreferrer">Whisper-large</a></td>
<td>✅</td>
<td> </td>
<td>✅</td>
<td> </td>
<td>Latest</td>
</tr>
</tbody>
</table>

@@ -253,6 +265,14 @@ model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True)
outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300)
```

To use Whisper for audio-to-text transcription, here is sample code:
```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
model_name = "Local path for whisper"  # please use a local model path
woq_config = WeightOnlyQuantConfig(use_ggml=True)  # currently, only Q4_0 quantization is supported
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config)
model('Local audio file')  # transcribes the given audio file
```

## How to use: Python script

Expand Down
31 changes: 23 additions & 8 deletions neural_speed/__init__.py
@@ -62,6 +62,8 @@ def __import_package(self, model_type):
import neural_speed.polyglot_cpp as cpp_model
elif model_type == "mistral":
import neural_speed.mistral_cpp as cpp_model
elif model_type == "whisper":
import neural_speed.whisper_cpp as cpp_model
else:
raise TypeError("Unsupported model type {}!".format(model_type))
self.module = cpp_model
@@ -204,11 +206,24 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
def is_token_end(self):
return self.model.is_token_end()

def __call__(self, input_ids, reinit=False, **kwargs):
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, **kwargs)
self.generate_round = 0
elif reinit:
self.model.reinit()
self.generate_round = 0
return self.model.evaluate(input_ids.tolist())
def __call__(self, model_input, reinit=False, **kwargs):
if self.model_type == 'whisper':
if self.model is None:
self.model = self.module.Model()
self.model.init_model(self.bin_file)
if os.path.isfile(model_input):
self.model.inference(model_input)
else:
print("Please input an audio file")
return
if isinstance(model_input, torch.Tensor):
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, **kwargs)
self.generate_round = 0
elif reinit:
self.model.reinit()
self.generate_round = 0
return self.model.evaluate(model_input.tolist())
else:
print("Please input torch.Tensor")
return
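The new `__call__` dispatches on input type: a file path string runs whisper inference, while a `torch.Tensor` of token ids goes through the usual `evaluate` path. Below is a minimal caller sketch; it is illustrative only, the `Model` class name and its initialization call are assumptions (they are not shown in this diff), and all paths are placeholders.

```python
# Illustrative sketch only: `Model` and `init` are assumed, not taken from
# this diff; all paths are placeholders.
import torch
from neural_speed import Model

# Whisper: __call__ accepts a path to an audio file and runs self.model.inference(path).
asr = Model()
asr.init("path/to/whisper-base")  # assumed initialization API
asr("sample.wav")

# LLM (e.g. llama): __call__ accepts a torch.Tensor of token ids and returns
# self.model.evaluate(input_ids.tolist()).
llm = Model()
llm.init("path/to/llama-7b")  # assumed initialization API
logits = llm(torch.tensor([[1, 2, 3, 4]]))
```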
3 changes: 2 additions & 1 deletion neural_speed/application/CMakeLists.txt
@@ -68,6 +68,7 @@ compile_quant(quant_chatglm2 quant_model.cpp chatglm2 chatglm2)
compile_quant(quant_baichuan quant_model.cpp baichuan baichuan)
compile_quant(quant_mistral quant_model.cpp mistral llama)
compile_quant(quant_qwen quant_model.cpp qwen qwen)
compile_quant(quant_whisper quant_whisper.cpp whisper whisper)

# all models running
if (NS_PYTHON_API)
@@ -126,4 +127,4 @@ compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama)
compile_run(run_qwen main_run.cpp main_pybind.cpp qwen qwen)

# speech recognition
compile_run(run_whisper audio_run.cpp "" whisper whisper)
compile_run(run_whisper audio_run.cpp whisper_pybind.cpp whisper whisper)
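With `whisper_pybind.cpp` now compiled into `run_whisper`, the whisper runner is exposed to Python as the `neural_speed.whisper_cpp` module imported in `neural_speed/__init__.py` above. A rough sketch of driving that module directly follows; the method names (`Model`, `init_model`, `inference`) come from the `__call__` diff above, while the file paths are placeholders.

```python
# Rough sketch of using the pybind module directly; Model/init_model/inference
# are the names used in neural_speed/__init__.py, paths are placeholders.
import neural_speed.whisper_cpp as whisper_cpp

model = whisper_cpp.Model()
model.init_model("whisper-base-q4_0.bin")  # quantized GGML file, e.g. produced by quant_whisper
model.inference("sample.wav")              # transcribes the given audio file
```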
2 changes: 2 additions & 0 deletions neural_speed/application/common.cpp
@@ -102,6 +102,8 @@ bool isValidFilename(const std::string& filename) {
return infile.good();
}

int64_t common_time_us() { return ne_time_us(); }

void gpt_print_usage(int /*argc*/, char** argv, const common_params& params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
2 changes: 2 additions & 0 deletions neural_speed/application/common.h
@@ -43,6 +43,8 @@

int32_t get_num_physical_cores();

int64_t common_time_us();

struct common_params {
int32_t n_threads = get_num_physical_cores();

1 change: 1 addition & 0 deletions neural_speed/application/main_pybind.cpp
@@ -40,6 +40,7 @@
#include "models/model_utils/model_types.h"
#include "models/model_utils/model_config.h"
#include "models/model_utils/model_utils.h"
#include "models/model_utils/quant_utils.h"

#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
#include <signal.h>
9 changes: 5 additions & 4 deletions neural_speed/application/quant_model.cpp
@@ -26,6 +26,7 @@
#include <tuple>

#include "common.h"
#include "models/model_utils/quant_utils.h"
#include "models/model_utils/model_utils.h"

std::shared_ptr<quant_layer_base> get_model_quant_layer(const std::string& model_name) {
@@ -55,24 +56,24 @@ int main(int argc, char** argv) {
printf("ne_ftype: %d\n", ftype);
const int nthread = q_params.nthread;

const int64_t t_main_start_us = model_time_us();
const int64_t t_main_start_us = common_time_us();

int64_t t_quantize_us = 0;
auto quant_layer = get_model_quant_layer(q_params.model_name);
// load the model
{
const int64_t t_start_us = model_time_us();
const int64_t t_start_us = common_time_us();

if (model_quantize(q_params, quant_layer)) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1;
}

t_quantize_us = model_time_us() - t_start_us;
t_quantize_us = common_time_us() - t_start_us;
}
// report timing
{
const int64_t t_main_end_us = model_time_us();
const int64_t t_main_end_us = common_time_us();

printf("\n");
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us / 1000.0);
74 changes: 74 additions & 0 deletions neural_speed/application/quant_whisper.cpp
@@ -0,0 +1,74 @@
// Copyright (c) 2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <regex> //NOLINT
#include "models/model_utils/quant_utils.h"
#include "common.h"

int main(int argc, char** argv) {
quant_params q_params;
if (quant_params_parse(argc, argv, q_params) == false) {
return 1;
}

// needed to initialize f16 tables
{
struct ne_init_params params = {0, NULL, false};
struct ne_context* ctx = ne_init(params);
ne_free(ctx);
}
const std::string fname_inp = q_params.model_file;
const std::string fname_out = q_params.out_file;
// printf("input_model_file:%s \n",fname_inp.c_str());

const ne_ftype ftype = quant_params_to_ftype(q_params);
if (ftype != NE_FTYPE_MOSTLY_Q4_0) {
fprintf(stderr, "%s: ITREX currently only supports quantizing models to q4_0\n", __func__);
return 1;
}

const int64_t t_main_start_us = common_time_us();

int64_t t_quantize_us = 0;

// load the model
{
const int64_t t_start_us = common_time_us();

if (!whisper_model_quantize(fname_inp, fname_out, ne_ftype(ftype))) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1;
}

t_quantize_us = common_time_us() - t_start_us;
}

// report timing
{
const int64_t t_main_end_us = common_time_us();

printf("\n");
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us / 1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
}

return 0;
}
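The `quant_whisper` binary built from this file reads an FP16/FP32 whisper GGML file and writes a Q4_0 one. A hedged sketch of invoking it from Python follows; the binary name comes from the CMake change above, but the command-line flag names are assumptions inferred from the `quant_params` fields (`model_file`, `out_file`) and may differ, so check the binary's usage output.

```python
# Hedged sketch: flag names are assumptions inferred from quant_params
# (model_file, out_file); consult quant_whisper's usage/help for the real ones.
import subprocess

subprocess.run(
    [
        "./build/bin/quant_whisper",             # built by compile_quant above (path assumed)
        "--model_file", "whisper-base-f32.bin",  # input GGML model (placeholder path)
        "--out_file", "whisper-base-q4_0.bin",   # only Q4_0 output is accepted
    ],
    check=True,
)
```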
