Skip to content

Commit

Permalink
add ipynb to ipex example
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Mar 20, 2024
1 parent 70874f4 commit 7e940e8
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 242 deletions.
31 changes: 0 additions & 31 deletions examples/ipex/text-generation/README.md

This file was deleted.

96 changes: 96 additions & 0 deletions examples/ipex/text-generation/ipex_text-generation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# IPEX model for text-generation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"IPEX model will replace the linears and some ops. Please note that IPEXModel uses a graph mode model to inference to accelerate the generation."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer\n",
"from optimum.intel.ipex import IPEXModelForCausalLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Framework not specified. Using pt to export the model.\n",
"Passing the argument `library_name` to `get_supported_tasks_for_model_type` is required, but got library_name=None. Defaulting to `transformers`. An error will be raised in a future version of Optimum if `library_name` is not provided.\n",
"/home/jiqingfe/frameworks.ai.pytorch.ipex-cpu/intel_extension_for_pytorch/frontend.py:462: UserWarning: Conv BatchNorm folding failed during the optimize process.\n",
" warnings.warn(\n",
"/home/jiqingfe/frameworks.ai.pytorch.ipex-cpu/intel_extension_for_pytorch/frontend.py:469: UserWarning: Linear BatchNorm folding failed during the optimize process.\n",
" warnings.warn(\n",
"/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/transformers/modeling_utils.py:4193: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n",
" warnings.warn(\n",
"/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:801: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
" if batch_size <= 0:\n",
"Passing the argument `library_name` to `get_supported_tasks_for_model_type` is required, but got library_name=None. Defaulting to `transformers`. An error will be raised in a future version of Optimum if `library_name` is not provided.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
"access to the `model_dtype` attribute is deprecated and will be removed after v1.18.0, please use `_dtype` instead.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet? Yes, you can.\n",
"\n",
"Yes, I can write Haikus in one tweet. I have no idea how to do that, but I'm sure\n"
]
}
],
"source": [
"model = IPEXModelForCausalLM.from_pretrained(\"gpt2\", torch_dtype=torch.bfloat16, export=True)\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
"input_sentence = [\"Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet?\"]\n",
"model_inputs = tokenizer(input_sentence, return_tensors=\"pt\")\n",
"generation_kwargs = dict(max_new_tokens=32, do_sample=False, num_beams=4, num_beam_groups=1, no_repeat_ngram_size=2, use_cache=True)\n",
"\n",
"generated_ids = model.generate(**model_inputs, **generation_kwargs)\n",
"output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
"print(output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ipex",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
211 changes: 0 additions & 211 deletions examples/ipex/text-generation/run_generation.py

This file was deleted.

0 comments on commit 7e940e8

Please sign in to comment.