model.py
from typing import Dict, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class Model:
    def __init__(self, **kwargs) -> None:
        # Model name (uncomment/comment to switch between the MPT flavors)
        # self.model_name = 'mosaicml/mpt-7b'
        self.model_name = 'mosaicml/mpt-7b-instruct'
        # self.model_name = 'mosaicml/mpt-7b-storywriter'
        # self.model_name = 'mosaicml/mpt-7b-chat'

        # Device
        self.device = 'cuda:0'

        # Tokenizer: MPT was trained with the EleutherAI GPT-NeoX-20B tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'left'

        # Attention implementation (uncomment to use the triton kernels;
        # also requires `import transformers`)
        # config = transformers.AutoConfig.from_pretrained(
        #     self.model_name,
        #     trust_remote_code=True
        # )
        # config.attn_config['attn_impl'] = 'triton'

        # Model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            # config=config,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16
        )
        self.model.to(device=self.device)
        self.model.eval()
    def preprocess(self, request: dict) -> dict:
        # Default generation arguments; per-request fields override them below
        generate_args = {
            'max_new_tokens': 100,
            'temperature': 1.0,
            'top_p': 1.0,
            'top_k': 50,
            'repetition_penalty': 1.0,
            'no_repeat_ngram_size': 0,
            'use_cache': True,
            'do_sample': True,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        if 'max_tokens' in request:
            generate_args['max_new_tokens'] = request['max_tokens']
        if 'temperature' in request:
            generate_args['temperature'] = request['temperature']
        if 'top_p' in request:
            generate_args['top_p'] = request['top_p']
        if 'top_k' in request:
            generate_args['top_k'] = request['top_k']
        request['generate_args'] = generate_args
        return request
    def generate(self, prompt: str, generate_args: dict) -> str:
        encoded_inp = self.tokenizer(prompt, return_tensors='pt', padding=True)
        # Move the input tensors onto the model's device
        for key, value in encoded_inp.items():
            encoded_inp[key] = value.to(self.device)
        with torch.no_grad():
            encoded_gen = self.model.generate(
                input_ids=encoded_inp['input_ids'],
                attention_mask=encoded_inp['attention_mask'],
                **generate_args,
            )
        decoded_gen = self.tokenizer.batch_decode(encoded_gen,
                                                  skip_special_tokens=True)
        # Return only the continuation, stripping the echoed prompt
        continuation = decoded_gen[0][len(prompt):]
        return continuation
    def predict(self, request: Dict) -> Dict[str, Optional[str]]:
        try:
            prompt = request.pop("prompt")
            completion = self.generate(prompt, request['generate_args'])
        except Exception as exc:
            return {"status": "error", "data": None, "message": str(exc)}
        return {"status": "success", "data": completion, "message": None}