Skip to content

Commit

Permalink
Internal change only.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683262823
  • Loading branch information
haozha111 authored and copybara-github committed Oct 7, 2024
1 parent 4f25328 commit 18d7630
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 5 deletions.
21 changes: 17 additions & 4 deletions ai_edge_torch/generative/examples/stable_diffusion/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ py_binary(
":clip",
":decoder",
":diffusion",
":encoder",
":util",
"//third_party/py/absl:app",
"//third_party/py/absl/flags",
"//third_party/py/ai_edge_torch",
"//third_party/py/ai_edge_torch/generative/utilities:stable_diffusion_loader",
"//third_party/py/torch:pytorch",
Expand All @@ -44,7 +45,11 @@ py_library(
name = "clip",
srcs = ["clip.py"],
deps = [
"//third_party/py/ai_edge_torch",
"//third_party/py/ai_edge_torch/generative/layers:attention",
"//third_party/py/ai_edge_torch/generative/layers:attention_utils",
"//third_party/py/ai_edge_torch/generative/layers:builder",
"//third_party/py/ai_edge_torch/generative/layers:model_config",
"//third_party/py/ai_edge_torch/generative/utilities:loader",
"//third_party/py/torch:pytorch",
],
)
Expand All @@ -53,7 +58,10 @@ py_library(
name = "decoder",
srcs = ["decoder.py"],
deps = [
"//third_party/py/ai_edge_torch",
"//third_party/py/ai_edge_torch/generative/layers:builder",
"//third_party/py/ai_edge_torch/generative/layers:model_config",
"//third_party/py/ai_edge_torch/generative/layers/unet:blocks_2d",
"//third_party/py/ai_edge_torch/generative/layers/unet:model_config",
"//third_party/py/ai_edge_torch/generative/utilities:stable_diffusion_loader",
"//third_party/py/torch:pytorch",
],
Expand Down Expand Up @@ -87,7 +95,11 @@ py_library(
name = "diffusion",
srcs = ["diffusion.py"],
deps = [
"//third_party/py/ai_edge_torch",
"//third_party/py/ai_edge_torch/generative/layers:builder",
"//third_party/py/ai_edge_torch/generative/layers:model_config",
"//third_party/py/ai_edge_torch/generative/layers/unet:blocks_2d",
"//third_party/py/ai_edge_torch/generative/layers/unet:model_config",
"//third_party/py/ai_edge_torch/generative/utilities:stable_diffusion_loader",
"//third_party/py/torch:pytorch",
],
)
Expand All @@ -100,6 +112,7 @@ py_library(
":util",
"//third_party/py/PIL:pil",
"//third_party/py/ai_edge_torch",
"//third_party/py/ai_edge_torch/generative/examples/stable_diffusion/samplers:__init__",
"//third_party/py/numpy",
"//third_party/py/tqdm",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from typing import Tuple

from absl import app
from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
Expand Down
52 changes: 52 additions & 0 deletions ai_edge_torch/generative/layers/unet/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
load("//third_party/bazel_rules/rules_python/python:py_library.bzl", "py_library")

package(
default_applicable_licenses = [
"//third_party/py/ai_edge_torch:license",
],
default_visibility = [
"//third_party/py/ai_edge_torch:__subpackages__",
],
)

py_library(
name = "builder",
srcs = ["builder.py"],
deps = [
":model_config",
"//third_party/py/torch:pytorch",
],
)

py_library(
name = "blocks_2d",
srcs = ["blocks_2d.py"],
deps = [
":builder",
":model_config",
"//third_party/py/ai_edge_torch/generative/layers:attention",
"//third_party/py/ai_edge_torch/generative/layers:builder",
"//third_party/py/ai_edge_torch/generative/layers:model_config",
"//third_party/py/torch:pytorch",
],
)

py_library(
name = "model_config",
srcs = ["model_config.py"],
deps = ["//third_party/py/ai_edge_torch/generative/layers:model_config"],
)

0 comments on commit 18d7630

Please sign in to comment.