Chapter09/Semantic_Segmentation_with_U_Net.ipynb #79
I am also getting an error. Here is my dataset class:
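```python
import cv2
import torch
from torch.utils.data import Dataset
# stems, read and randint are assumed to come from torch_snippets, as in the
# chapter notebook; tfms and device are defined earlier in the notebook.
from torch_snippets import *

class SegData(Dataset):
    def __init__(self, split):
        # NOTE: this path has one more 'dataset1/' level than the paths used in
        # __getitem__; make sure both point at the same folder.
        self.items = stems(f'./dataset1/dataset1/images_prepped_{split}')
        self.split = split
    def __len__(self):
        return len(self.items)
    def __getitem__(self, ix):
        image = read(f'./dataset1/images_prepped_{self.split}/{self.items[ix]}.png', 1)
        image = cv2.resize(image, (224,224))
        # read(f'./dataset1/annotations_prepped_{self.split}/{self.items[ix]}.png')
        # Load the mask as a single-channel (H, W) array of class ids
        mask = cv2.imread(f'./dataset1/annotations_prepped_{self.split}/{self.items[ix]}.png', cv2.IMREAD_GRAYSCALE)
        # Nearest-neighbour resizing keeps class ids intact; the default
        # bilinear interpolation can blend neighbouring ids at boundaries
        mask = cv2.resize(mask, (224,224), interpolation=cv2.INTER_NEAREST)
        return image, mask
    def choose(self): return self[randint(len(self))]
    def collate_fn(self, batch):
        ims, masks = list(zip(*batch))
        ims = torch.cat([tfms(im.copy()/255.)[None] for im in ims]).float().to(device)
        # Each 2-D mask becomes (1, 224, 224); the stacked targets are (N, 224, 224)
        ce_masks = torch.cat([torch.Tensor(mask[None]) for mask in masks]).long().to(device)
        return ims, ce_masks
```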
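For context, the target shape in the error follows from the stacking in `collate_fn`: each mask gets a leading batch axis and is concatenated, so a 2-D `(H, W)` mask gives an `(N, H, W)` target, while a mask that still has three channels (for example, one loaded by `cv2.imread` with its default colour flag) gives the 4-D `(N, H, W, 3)` target reported in the error. A minimal sketch, with synthetic zero-filled arrays standing in for real annotation files; the helper `stack_masks` is hypothetical and only mirrors the `ce_masks` line:

```python
import numpy as np
import torch

def stack_masks(masks):
    # Hypothetical helper mirroring the ce_masks line of collate_fn:
    # add a leading batch axis to each mask, then concatenate along it.
    return torch.cat([torch.as_tensor(m[None]) for m in masks]).long()

# Synthetic single-channel masks, as cv2.IMREAD_GRAYSCALE would return
gray = [np.zeros((224, 224), dtype=np.uint8) for _ in range(4)]
# Synthetic three-channel masks, as a colour read would return
color = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(4)]

print(stack_masks(gray).shape)   # torch.Size([4, 224, 224])
print(stack_masks(color).shape)  # torch.Size([4, 224, 224, 3])
```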
The error is raised in `UnetLoss`:

```
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [4, 224, 224, 3]
in UnetLoss(preds, targets)
      1 ce = nn.CrossEntropyLoss()
      2 def UnetLoss(preds, targets):
----> 3     ce_loss = ce(preds, targets)
      4     acc = (torch.max(preds, 1)[1] == targets).float().mean()
      5     return ce_loss, acc
```
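For a per-pixel loss like this, `nn.CrossEntropyLoss` expects raw scores of shape `(N, C, H, W)` and integer class-id targets of shape `(N, H, W)`, which is why a 4-D target is rejected. The sketch below reproduces the error with random tensors (the batch size of 4 and the 12 classes are assumptions for illustration) and then drops the extra axis, assuming all three channels of the mask hold the same class id:

```python
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()

def UnetLoss(preds, targets):
    # preds: (N, C, H, W) raw scores; targets: (N, H, W) integer class ids
    ce_loss = ce(preds, targets)
    acc = (torch.max(preds, 1)[1] == targets).float().mean()
    return ce_loss, acc

N, C, H, W = 4, 12, 224, 224  # assumed batch size and class count
preds = torch.randn(N, C, H, W)
# Same class id repeated over 3 channels -> shape [4, 224, 224, 3], as in the error
bad_targets = torch.randint(0, C, (N, H, W, 1)).repeat(1, 1, 1, 3)

raised = False
try:
    UnetLoss(preds, bad_targets)
except RuntimeError as e:
    # The 3-D target check from the traceback fails here
    raised = True
    print("RuntimeError:", e)

# Keep a single channel to get (N, H, W) class-id targets
good_targets = bad_targets[..., 0]
loss, acc = UnetLoss(preds, good_targets)
print(good_targets.shape, loss.item(), acc.item())
```

If the channels are not identical (for example, a colour-coded mask), they would instead need to be mapped to class ids, which this sketch does not cover.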