Skip to content

Commit

Permalink
Fix issue with 3d masks.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 13, 2024
1 parent 3748e7e commit 3b9a6cf
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,19 +853,18 @@ def reshape_mask(input_mask, output_shape):
dims = len(output_shape) - 2

if dims == 1:
mask = input_mask
scale_mode = "linear"

if dims == 2:
mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "bilinear"

if dims == 3:
if len(input_mask.shape) < 5:
mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "trilinear"

mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_mode)
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
Expand Down

0 comments on commit 3b9a6cf

Please sign in to comment.