Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

修复V100无法运行MiniCPM-V-2_6问题 (Fix: V100 GPUs cannot run MiniCPM-V-2_6 — fall back to float16 when bfloat16 is unsupported) #403

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions web_demo_2.6.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
device = args.device
assert device in ['cuda', 'mps']


TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
0] >= 8 else torch.float16

# Load model
model_path = 'openbmb/MiniCPM-V-2_6'
if 'int4' in model_path:
Expand All @@ -44,7 +48,7 @@
if args.multi_gpus:
from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
with init_empty_weights():
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=TORCH_TYPE)
device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
device_id = device_map["llm.model.embed_tokens"]
Expand All @@ -63,9 +67,9 @@
device_map["llm.model.layers.16"] = device_id2
#print(device_map)

model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
model = load_checkpoint_and_dispatch(model, model_path, dtype=TORCH_TYPE, device_map=device_map)
else:
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=TORCH_TYPE)
model = model.to(device=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
Expand Down Expand Up @@ -554,4 +558,3 @@ def select_chat_type(_tab, _app_cfg):

# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0")