Skip to content

Commit

Permalink
Add 2D blocks used in diffusion model of stable diffusion. (#50)
Browse files Browse the repository at this point in the history
* Add CrossAttentionBlock2D and TransformerBlock2D which are basic modules
used in UNet of diffusion model.

* Add DownEncoderBlock2D and downsampling related configs
  • Loading branch information
yichunk authored Jun 12, 2024
1 parent 6b0ba07 commit fca58c3
Show file tree
Hide file tree
Showing 4 changed files with 552 additions and 51 deletions.
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/stable_diffusion/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig):
num_layers=config.layers_per_block,
add_upsample=not_final_block,
upsample_conv=True,
sampling_config=unet_cfg.SamplingConfig(
sampling_config=unet_cfg.UpSamplingConfig(
2, unet_cfg.SamplingType.NEAREST
),
)
Expand Down Expand Up @@ -271,7 +271,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
)

att_config = unet_cfg.AttentionBlock2DConfig(
dims=block_out_channels[-1],
dim=block_out_channels[-1],
normalization_config=norm_config,
attention_config=layers_cfg.AttentionConfig(
num_heads=1,
Expand Down
Loading

0 comments on commit fca58c3

Please sign in to comment.