Skip to content

Commit

Permalink
add some nodes for batch
Browse files Browse the repository at this point in the history
  • Loading branch information
lldacing committed Oct 18, 2024
1 parent 30ac286 commit 8086ad3
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 6 deletions.
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ Tips: base64格式字符串比较长,会导致界面卡顿,接口请求带
| ForEachClose | 循环结束节点 |
| LoadJsonStrToList | json字符串转换为对象列表 |
| GetValueFromJsonObj | 从对象中获取指定key的值 |
| FilterValueForList | 根据指定值过滤列表中元素 ||
| FilterValueForList | 根据指定值过滤列表中元素 |
| SliceList | 列表切片 |
| LoadLocalFilePath | 列出给定路径下的文件列表 |
| LoadImageFromLocalPath | 根据图片全路径加载图片 |
| LoadMaskFromLocalPath | 根据遮罩全路径加载遮罩 | |
| IsNoneOrEmpty | 判断是否为空或空字符串或空列表或空字典 |
| IsNoneOrEmptyOptional | 为空时返回指定值(惰性求值),否则返回原值 |
| EmptyOutputNode | 空的输出类型节点 |

### 示例
![save api extended](docs/example_note.png)
Expand All @@ -76,6 +83,9 @@ Tips: base64格式字符串比较长,会导致界面卡顿,接口请求带
![save api extended](example/example_3.png)

## 更新记录
### 2024-10-18
- 新增节点:SliceList、LoadLocalFilePath、LoadImageFromLocalPath、LoadMaskFromLocalPath、IsNoneOrEmpty、IsNoneOrEmptyOptional、EmptyOutputNode

### 2024-09-29
- 新增节点:FilterValueForList

Expand Down
112 changes: 111 additions & 1 deletion easyapi/ImageNode.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import base64
import copy
import io
import os

import numpy as np
import torch
from PIL import ImageOps, Image
from PIL import ImageOps, Image, ImageSequence

import node_helpers
from nodes import LoadImage
from comfy.cli_args import args
from PIL.PngImagePlugin import PngInfo
Expand Down Expand Up @@ -340,6 +344,108 @@ def convert(self, image):
return encoded_image, img, mask


class LoadImageFromLocalPath:
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"image_path": ("STRING", {"default": ""},)
},
}

CATEGORY = "EasyApi/Image"

RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image_path):

img = node_helpers.pillow(Image.open, image_path)

output_images = []
output_masks = []
w, h = None, None

excluded_formats = ['MPO']
# 遍历图像的每一帧
for i in ImageSequence.Iterator(img):
# 旋转图像
i = node_helpers.pillow(ImageOps.exif_transpose, i)

if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
# 将图像转换为RGB格式
image = i.convert("RGB")

if len(output_images) == 0:
w = image.size[0]
h = image.size[1]

if image.size[0] != w or image.size[1] != h:
continue

# 将图像转换为浮点数组 (H,W,Channel)
image = np.array(image).astype(np.float32) / 255.0
# 先把图片转成3维张量,并再在最前面添加一个维度,变成4维(1, H, W,Channel)
image = torch.from_numpy(image)[None,]
# 如果图像包含alpha通道,则将其转换为掩码
if 'A' in i.getbands():
# 计算后结果数组中透明像素会是0
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
# 把数组中透明像素设为1
mask = 1. - torch.from_numpy(mask)
else:
# 否则,创建一个64x64的零张量作为掩码
mask = torch.zeros((64, 64,), dtype=torch.float32, device="cpu")
# 将图像和掩码添加到输出列表中
output_images.append(image)
output_masks.append(mask.unsqueeze(0))

if len(output_images) > 1 and img.format not in excluded_formats:
# 如果有多个图像,则将它们按维度0拼接在一起
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
# 否则,返回单个图像和掩码
else:
output_image = output_images[0]
output_mask = output_masks[0]
# 返回输出图像和掩码
return (output_image, output_mask)


class LoadMaskFromLocalPath:
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"image_path": ("STRING", {"default": ""}),
"channel": (s._color_channels, ),
}
}

CATEGORY = "EasyApi/Image"

RETURN_TYPES = ("MASK",)
FUNCTION = "load_mask"
def load_mask(self, image_path, channel):
i = node_helpers.pillow(Image.open, image_path)
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA")
mask = None
c = channel[0].upper()
if c in i.getbands():
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
mask = torch.from_numpy(mask)
if c == 'A':
mask = 1. - mask
else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
return (mask.unsqueeze(0),)


NODE_CLASS_MAPPINGS = {
"Base64ToImage": Base64ToImage,
"LoadImageFromURL": LoadImageFromURL,
Expand All @@ -351,6 +457,8 @@ def convert(self, image):
"MaskToBase64Image": MaskToBase64Image,
"MaskImageToBase64": MaskImageToBase64,
"LoadImageToBase64": LoadImageToBase64,
"LoadImageFromLocalPath": LoadImageFromLocalPath,
"LoadMaskFromLocalPath": LoadMaskFromLocalPath,
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
Expand All @@ -365,4 +473,6 @@ def convert(self, image):
"MaskToBase64Image": "Mask To Base64 Image",
"MaskImageToBase64": "Mask Image To Base64",
"LoadImageToBase64": "Load Image To Base64",
"LoadImageFromLocalPath": "Load Image From Local Path",
"LoadMaskFromLocalPath": "Load Mask From Local Path",
}
Loading

0 comments on commit 8086ad3

Please sign in to comment.