
Question about why getting blank image when using flux ControlNet. #20

Open
chuck-ma opened this issue Nov 13, 2024 · 7 comments

Comments

@chuck-ma

chuck-ma commented Nov 13, 2024

Image generation code (without ControlNet):

```python
import os
import types

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.pipelines import FluxControlNetPipeline  # used in the ControlNet snippet below
from huggingface_hub import hf_hub_download
from nunchaku.models.flux import (
    SVD_RANK,
    EmbedND,
    NunchakuFluxModel,
    QuantizedFluxModel,
    load_quantized_model,
)

dtype = torch.bfloat16


def load_flux_model(
    model_path: str,
    load_from_file: bool = True,
    dtype: torch.dtype = torch.bfloat16,
) -> FluxTransformer2DModel:
    """
    Load a FLUX model from a single file or a pretrained directory.

    Args:
        model_path: Path to a safetensors file or to a pretrained model directory.
        load_from_file: Whether to load from a single file.
        dtype: Compute precision of the model.
    """
    quantization_config = None

    if load_from_file:
        model = FluxTransformer2DModel.from_single_file(
            model_path, quantization_config=quantization_config, torch_dtype=dtype
        )
    else:
        model = FluxTransformer2DModel.from_pretrained(
            model_path, quantization_config=quantization_config, torch_dtype=dtype
        )

    return model


# Download the FP8 FLUX.1-dev checkpoint and load it as the base transformer.
ckpt_repo = "Kijai/flux-fp8"
ckpt_filename = "flux1-dev-fp8-e4m3fn.safetensors"
ckpt_path = hf_hub_download(ckpt_repo, filename=ckpt_filename)

model = load_flux_model(ckpt_path)

# Load the SVDQuant INT4 model; if the path is not local, download it from the Hub.
qmodel_path = "mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors"

if not os.path.exists(qmodel_path):
    hf_repo_id = os.path.dirname(qmodel_path)
    filename = os.path.basename(qmodel_path)
    qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)

m = load_quantized_model(qmodel_path, "cuda")


def inject_transformer(
    transformer_model: FluxTransformer2DModel, m: QuantizedFluxModel
) -> FluxTransformer2DModel:
    """Inject a custom transformer model.

    Args:
        transformer_model: The original transformer model.
        m: The custom (quantized) model to inject.
    """
    # Inject the positional embedding
    transformer_model.pos_embed = EmbedND(
        dim=transformer_model.inner_dim, theta=10000, axes_dim=[16, 56, 56]
    )

    # Replace the transformer blocks: the whole block stack now runs inside a
    # single NunchakuFluxModel, and the single-stream block list is left empty.
    transformer_model.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m)])
    transformer_model.single_transformer_blocks = torch.nn.ModuleList([])

    def update_params(self: FluxTransformer2DModel, path: str):
        # Accept either a local path or an "<org>/<repo>/<filename>" Hub path.
        if not os.path.exists(path):
            hf_repo_id = os.path.dirname(path)
            filename = os.path.basename(path)
            path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
        block = self.transformer_blocks[0]
        assert isinstance(block, NunchakuFluxModel)
        block.m.load(path, True)

    def set_lora_scale(self: FluxTransformer2DModel, scale: float):
        block = self.transformer_blocks[0]
        assert isinstance(block, NunchakuFluxModel)
        block.m.setLoraScale(SVD_RANK, scale)

    # Attach the helpers as bound methods on this model instance.
    transformer_model.nunchaku_update_params = types.MethodType(
        update_params, transformer_model
    )
    transformer_model.nunchaku_set_lora_scale = types.MethodType(
        set_lora_scale, transformer_model
    )

    return transformer_model


model = inject_transformer(model, m)
model = model.to("cuda")

flux_id = "black-forest-labs/FLUX.1-dev"

pipeline = FluxPipeline.from_pretrained(
    flux_id,
    transformer=model,
    torch_dtype=dtype,
)

pipeline.vae.to("cuda")
pipeline.text_encoder.to("cuda")
pipeline.text_encoder_2.to("cuda")

image = pipeline(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]

image  # displays the image when run in a notebook
```
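The "local path or Hub download" fallback appears twice in the code above (for `qmodel_path` and inside `update_params`). A minimal sketch of factoring it into one helper follows; `split_hub_path` and `resolve_checkpoint` are hypothetical names, and the sketch assumes, like the original code, that a path which does not exist locally has the form `<org>/<repo>/<filename>`.

```python
import os


def split_hub_path(path: str) -> tuple[str, str]:
    """Split '<org>/<repo>/<filename>' into (repo_id, filename)."""
    return os.path.dirname(path), os.path.basename(path)


def resolve_checkpoint(path: str) -> str:
    """Return `path` unchanged if it exists locally; otherwise treat it as
    '<org>/<repo>/<filename>' and download the file from the Hugging Face Hub."""
    if os.path.exists(path):
        return path
    # Imported lazily so the local-path branch works without huggingface_hub.
    from huggingface_hub import hf_hub_download

    repo_id, filename = split_hub_path(path)
    return hf_hub_download(repo_id=repo_id, filename=filename)
```

With this helper, the quantized-model lookup would become `qmodel_path = resolve_checkpoint("mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors")`, and `update_params` could call it the same way.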

Without Flux ControlNet, everything works fine, and generation is quite fast.

However, with Flux ControlNet I only get a blank image, and no errors are raised.

Image generation code with ControlNet:

```python
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Free the previous pipeline before building the ControlNet one.
del pipeline
torch.cuda.empty_cache()

# Load the depth ControlNet
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Depth",
    torch_dtype=torch.bfloat16,
)

flux_id = "black-forest-labs/FLUX.1-dev"

# Reuse the Nunchaku-injected transformer `model` from the previous snippet.
pipeline = FluxControlNetPipeline.from_pretrained(
    flux_id,
    transformer=model,
    controlnet=controlnet,
    torch_dtype=dtype,
)

pipeline.controlnet.to("cuda")
pipeline.vae.to("cuda")
pipeline.text_encoder.to("cuda")
pipeline.text_encoder_2.to("cuda")  # moved to the GPU as in the first snippet

control_image = load_image(
    "https://hf-mirror.com/jasperai/Flux.1-dev-Controlnet-Depth/resolve/main/examples/depth.jpg"
)

prompt = "a statue of a gnome in a field of purple tulips"

image = pipeline(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
    height=control_image.size[1],
    width=control_image.size[0],
).images[0]

image  # displays the image when run in a notebook
```
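To confirm that `image` from the ControlNet call is genuinely blank (every pixel the same value) rather than just dark or low-contrast, one can look at simple pixel statistics. A minimal sketch, assuming NumPy is available; `describe_image`, `is_blank`, and the tolerance of 1.0 on the 0-255 scale are illustrative choices, not part of diffusers or Nunchaku.

```python
import numpy as np


def describe_image(image) -> dict:
    """Summarize pixel statistics of a PIL image or any array-like.

    A standard deviation near zero means (almost) every pixel has the same
    value, i.e. the image is effectively blank.
    """
    pixels = np.asarray(image, dtype=np.float32)
    return {
        "min": float(pixels.min()),
        "max": float(pixels.max()),
        "std": float(pixels.std()),
    }


def is_blank(image, tol: float = 1.0) -> bool:
    """Heuristic: treat the image as blank if its pixel std is below `tol`."""
    return describe_image(image)["std"] < tol
```

`is_blank(image)` can be called directly on the PIL image returned by the pipeline, because `np.asarray` converts PIL images to arrays; `describe_image(image)` also shows whether the constant value is black (0) or some other color.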

Any help?

chuck-ma changed the title from "Question about use flux controlnet" to "Question about why getting blank image when using flux ControlNet." on Nov 13, 2024
@chuck-ma
Author

It seems the library currently uses diffusers 0.30.3, which doesn't support this ControlNet model, so the code would need to be upgraded to make ControlNet work. Correct me if I'm wrong.
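One way to check this hypothesis is to look up the installed diffusers version and whether the package exposes `FluxControlNetPipeline`. A minimal sketch using only the standard library; `check_support` is a hypothetical helper name, and it only reports what is installed, not whether Nunchaku's injected transformer handles ControlNet inputs.

```python
import importlib
import importlib.metadata
from typing import Optional, Tuple


def check_support(
    package: str = "diffusers", attr: str = "FluxControlNetPipeline"
) -> Tuple[Optional[str], bool]:
    """Return (installed version or None, whether `package` exposes `attr`)."""
    try:
        version = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None, False
    try:
        module = importlib.import_module(package)
    except ImportError:
        # Distribution is installed but the module cannot be imported.
        return version, False
    return version, hasattr(module, attr)
```

Calling `check_support()` in the same environment as the snippets above returns a tuple such as the installed version string and `True` or `False`.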

@lmxyy
Collaborator

lmxyy commented Nov 14, 2024

We do not have native support for ControlNet for now. We will try to support it in our next release.

@chuck-ma
Author

chuck-ma commented Nov 15, 2024

I've opened a pull request that tries to add ControlNet support to Nunchaku. However, it still produces a blank image.

#25

Any guidance on how to make it work?

I think that if we make FluxModel.cpp accept two parameters, controlnet_block_samples and controlnet_single_block_samples, everything should work. I don't know how to debug the problem. Any help? @lmxyy @sxtyzhangzk @synxlin
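For context on why these two residual lists matter: in recent diffusers versions, `FluxTransformer2DModel.forward` adds an entry of `controlnet_block_samples` (and, for single-stream blocks, `controlnet_single_block_samples`) to the hidden states after each block, choosing the entry with a ceil-based interval. The sketch below re-implements only that index arithmetic (the default, non-repeat path) in plain Python. The function name and the ControlNet sizes (4 double-stream samples, 10 single-stream samples) are illustrative assumptions; FLUX.1-dev itself has 19 double-stream and 38 single-stream blocks. After `inject_transformer`, the model exposes one block and zero single-stream blocks, so, assuming the fused module runs the whole stack internally, the schedule collapses.

```python
import math
from typing import List


def controlnet_residual_schedule(num_blocks: int, num_samples: int) -> List[int]:
    """For each transformer block, return the index of the ControlNet residual
    added after it, using the interval ceil(num_blocks / num_samples)."""
    if num_blocks == 0 or num_samples == 0:
        return []
    interval = math.ceil(num_blocks / num_samples)
    return [i // interval for i in range(num_blocks)]


# Stock FLUX.1-dev: 19 double-stream and 38 single-stream blocks.
print(controlnet_residual_schedule(19, 4))   # residuals 0..3 are all used
print(controlnet_residual_schedule(38, 10))  # residuals 0..9 are all used
# After inject_transformer: one fused block and no single-stream blocks.
print(controlnet_residual_schedule(1, 4))    # only residual 0, added once after the fused block
print(controlnet_residual_schedule(0, 10))   # single-stream residuals are never added
```

In other words, with the injected model the ControlNet residuals are not spread across the block stack the way the ControlNet expects, which is consistent with the idea of passing both lists into FluxModel.cpp so they can be added per block.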

@lmxyy
Collaborator

lmxyy commented Nov 16, 2024

Thanks for your PR. I will take a close look at it next week. Have a good weekend!

@chuck-ma
Author

chuck-ma commented Nov 17, 2024

Thank you for your time.

@chuck-ma
Author

Hello! Is there any progress? If you can share some debugging ideas, I'd be happy to help speed things up. @lmxyy

@lmxyy
Collaborator

lmxyy commented Nov 21, 2024

I am busy for the next two days. Hopefully I can review your PR this Friday.
