Skip to content

Commit

Permalink
Torchchat CLI pipeline for Multimodal Models (pytorch#1140)
Browse files Browse the repository at this point in the history
* Torchchat CLI pipeline for Multimodal Models

* Remove torchaudio check; we don't use it

* Flip the imports back for ET

---------

Co-authored-by: vmpuri <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
  • Loading branch information
3 people authored Sep 15, 2024
1 parent 6fae164 commit 26c1d8b
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 90 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ jobs:
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
python3 -c 'import torchvision;print(f"torchvision: {torchvision.__version__, torchvision.version.git_version}")'
python3 -c 'import torchaudio;print(f"torchaudio: {torchaudio.__version__, torchaudio.version.git_version}")'
cd ../..
echo "Inside: ${PWD}"
Expand Down
46 changes: 35 additions & 11 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@
import torch._dynamo.config
import torch._inductor.config
import torch.nn as nn

try:
from _torchchat_test_script import flamingo_meta_to_tune
except ImportError:
pass

from distributed import (
init_distributed,
launch_distributed,
ParallelDims,
parallelize_llama,
)

from torch.distributed.device_mesh import DeviceMesh

from torchchat.model import Model
from torchtune.models.convert_weights import meta_to_tune

from torchtune.training import set_default_dtype

from torchchat.model import Model, ModelType

from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
Expand All @@ -35,10 +46,6 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

from torchtune.models.convert_weights import meta_to_tune




@dataclass
class BuilderArgs:
Expand Down Expand Up @@ -143,7 +150,6 @@ def from_args(cls, args): # -> BuilderArgs:
if "chat" in path_basename or "instruct" in path_basename:
is_chat_model = True


output_pte_path = getattr(args, "output_pte_path", None)
output_dso_path = getattr(args, "output_dso_path", None)
if output_pte_path and args.dtype.startswith("fast"):
Expand Down Expand Up @@ -234,7 +240,12 @@ def validate_model(

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
use_tiktoken = model.config.transformer_args["text"].use_tiktoken
text_args = model.config.transformer_args.get("text")
if text_args is None:
# TODO: Will be refactored: Currently, the only model that doesn't have text in transformer_args is Flamingo
use_tiktoken = model.config.model_type == ModelType.Flamingo
else:
use_tiktoken = text_args.use_tiktoken

if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
raise RuntimeError(
Expand Down Expand Up @@ -266,7 +277,9 @@ def from_args(cls, args): # -> TokenizerArgs:
raise RuntimeError("cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
raise RuntimeError(
f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
)

return cls(
tokenizer_path=tokenizer_path,
Expand Down Expand Up @@ -335,7 +348,9 @@ def _load_model_default(builder_args, only_config=False):

if builder_args.params_table and builder_args.params_table.endswith("Tune"):
print("Loading Tune checkpoint")
meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
meta_checkpoint = torch.load(
str(builder_args.checkpoint_path), mmap=True, weights_only=True
)
checkpoint = meta_to_tune(meta_checkpoint)
elif builder_args.checkpoint_dir is not None:
# Load multiple checkpoints; ignore the single path.
Expand Down Expand Up @@ -372,8 +387,17 @@ def _load_model_default(builder_args, only_config=False):
if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]

checkpoint = {"model." + k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint, assign=True, strict=True)
if model.config.model_type == ModelType.Flamingo:
# TODO: Refactor this. For now, overwrite the model with the model loaded from params_path
with set_default_dtype(builder_args.precision), torch.device(
builder_args.device
):
model = Model.from_params(builder_args.params_path)
state_dict = flamingo_meta_to_tune(checkpoint)
model.model.load_state_dict(state_dict)
else:
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint, assign=True, strict=True)

return model

Expand Down
9 changes: 8 additions & 1 deletion torchchat/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def check_args(args, verb: str) -> None:
# different semantics.
if (
verb not in INVENTORY_VERBS
and args.model
and getattr(args, "model", None)
and not is_model_downloaded(args.model, args.model_directory)
):
download_and_convert(args.model, args.model_directory, args.hf_token)
Expand Down Expand Up @@ -320,6 +320,13 @@ def _add_generation_args(parser, verb: str) -> None:
help="Number of samples",
)

generator_parser.add_argument(
"--image-prompts",
nargs="+",
type=str,
default=None,
help="Paths to image files used as image prompts for multimodal models. Currently, 1 image input is supported.",
)
generator_parser.add_argument(
"--chat",
action="store_true",
Expand Down
Loading

0 comments on commit 26c1d8b

Please sign in to comment.