Commit 1db2391
adam9500370 committed Nov 19, 2018
1 parent 89f4abe
Showing 5 changed files with 21 additions and 27 deletions.
8 changes: 3 additions & 5 deletions ptsemseg/loss.py
@@ -11,10 +11,10 @@ def cross_entropy2d(input, target, weight=None, size_average=True):
     # Handle inconsistent size between input and target
     if h > ht and w > wt: # upsample labels
         target = target.unsqueeze(1)
-        target = F.upsample(target, size=(h, w), mode="nearest")
+        target = F.interpolate(target.float(), size=(h, w), mode="nearest").long()
         target = target.squeeze(1)
     elif h < ht and w < wt: # upsample images
-        input = F.upsample(input, size=(ht, wt), mode="bilinear")
+        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
     elif h != ht and w != wt:
         raise Exception("Only support upsampling")

@@ -72,9 +72,7 @@ def multi_scale_cross_entropy2d(
     if scale_weight == None: # scale_weight: torch tensor type
         n_inp = len(input)
         scale = 0.4
-        scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp))
-        if input.is_cuda:
-            scale_weight = scale_weight.cuda()
+        scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to('cuda' if input.is_cuda else 'cpu')

     loss = 0.0
     for i, inp in enumerate(input):
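Note on the cross_entropy2d change: F.interpolate has no kernel for integer (Long) tensors in the PyTorch versions this code targets, so the label map is cast to float for the nearest-neighbour resize and back to long afterwards. A minimal standalone sketch of that pattern (the shapes and class count are made up for illustration):

import torch
import torch.nn.functional as F

target = torch.randint(0, 19, (2, 64, 128))   # (N, H, W) class indices, dtype long
target = target.unsqueeze(1).float()          # (N, 1, H, W); interpolate needs a float tensor
target = F.interpolate(target, size=(128, 256), mode="nearest")
target = target.squeeze(1).long()             # back to (N, H', W') class indices
assert target.dtype == torch.long and target.shape == (2, 128, 256)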
6 changes: 3 additions & 3 deletions ptsemseg/loss/loss.py
@@ -11,10 +11,10 @@ def cross_entropy2d(input, target, weight=None, size_average=True):
     # Handle inconsistent size between input and target
     if h > ht and w > wt: # upsample labels
         target = target.unsqueeze(1)
-        target = F.upsample(target, size=(h, w), mode="nearest")
+        target = F.interpolate(target.float(), size=(h, w), mode="nearest").long()
         target = target.squeeze(1)
     elif h < ht and w < wt: # upsample images
-        input = F.upsample(input, size=(ht, wt), mode="bilinear")
+        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
     elif h != ht and w != wt:
         raise Exception("Only support upsampling")

@@ -33,7 +33,7 @@ def multi_scale_cross_entropy2d(
     if scale_weight == None: # scale_weight: torch tensor type
         n_inp = len(input)
         scale = 0.4
-        scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp))
+        scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to('cuda' if input.is_cuda else 'cpu')

     loss = 0.0
     for i, inp in enumerate(input):
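The default scale_weight built above weights the loss of the i-th output by 0.4**i. A minimal sketch of what that produces, not part of the commit, using a hypothetical list of per-scale logits and .device rather than the is_cuda ternary:

import torch

preds = [torch.randn(2, 19, 64, 128), torch.randn(2, 19, 32, 64)]   # hypothetical per-scale logits
n_inp = len(preds)
scale = 0.4
scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float())
scale_weight = scale_weight.to(preds[0].device)   # keep the weights on the model's device
# scale_weight == tensor([1.0000, 0.4000]) for two scales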
13 changes: 7 additions & 6 deletions ptsemseg/models/icnet.py
@@ -179,7 +179,7 @@ def forward(self, x):
         h, w = x.shape[2:]

         # H, W -> H/2, W/2
-        x_sub2 = interp(x, output_size=get_interp_size(x, s_factor=2))
+        x_sub2 = F.interpolate(x, size=get_interp_size(x, s_factor=2), mode='bilinear', align_corners=True)

         # H/2, W/2 -> H/4, W/4
         x_sub2 = self.convbnrelu1_1(x_sub2)
@@ -193,7 +193,7 @@ def forward(self, x):
         x_sub2 = self.res_block2(x_sub2)
         x_sub2 = self.res_block3_conv(x_sub2)
         # H/16, W/16 -> H/32, W/32
-        x_sub4 = interp(x_sub2, output_size=get_interp_size(x_sub2, s_factor=2))
+        x_sub4 = F.interpolate(x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode='bilinear', align_corners=True)
         x_sub4 = self.res_block3_identity(x_sub4)

         x_sub4 = self.res_block4(x_sub4)
@@ -209,18 +209,19 @@ def forward(self, x):
         x_sub24, sub4_cls = self.cff_sub24(x_sub4, x_sub2)
         x_sub12, sub24_cls = self.cff_sub12(x_sub24, x_sub1)

-        x_sub12 = F.upsample(
-            x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear"
+        x_sub12 = F.interpolate(
+            x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True
         )
         sub124_cls = self.classification(x_sub12)

         if self.training:
-            return sub4_cls, sub24_cls, sub124_cls
+            return (sub124_cls, sub24_cls, sub4_cls)
         else: # eval mode
-            sub124_cls = F.upsample(
+            sub124_cls = F.interpolate(
                 sub124_cls,
                 size=get_interp_size(sub124_cls, z_factor=4),
                 mode="bilinear",
+                align_corners=True
             )  # Test only
             return sub124_cls

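Every resize in icnet.py now goes through F.interpolate with an explicit align_corners=True, which reproduces the sampling grid of the older F.upsample behaviour these calls replace; the newer default (align_corners=False) would shift the interior samples and change the network's numerics. A small standalone check, not part of the commit:

import torch
import torch.nn.functional as F

x = torch.tensor([[[[0.0, 1.0]]]])   # shape (1, 1, 1, 2)
up_true = F.interpolate(x, size=(1, 4), mode="bilinear", align_corners=True)
up_false = F.interpolate(x, size=(1, 4), mode="bilinear", align_corners=False)
print(up_true)    # tensor([[[[0.0000, 0.3333, 0.6667, 1.0000]]]])
print(up_false)   # tensor([[[[0.0000, 0.2500, 0.7500, 1.0000]]]])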
19 changes: 7 additions & 12 deletions ptsemseg/models/utils.py
@@ -3,8 +3,6 @@
 import numpy as np
 import torch.nn.functional as F

-from torch.autograd import Variable
-

 class conv2DBatchNorm(nn.Module):
     def __init__(
@@ -572,7 +570,7 @@ def forward(self, x):
             # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
             if self.model_name != "icnet":
                 out = module(out)
-            out = F.upsample(out, size=(h, w), mode="bilinear")
+            out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
             output_slices.append(out)

         return torch.cat(output_slices, dim=1)
@@ -586,7 +584,7 @@ def forward(self, x):
             # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
             if self.model_name != "icnet":
                 out = module(out)
-            out = F.upsample(out, size=(h, w), mode="bilinear")
+            out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
             pp_sum = pp_sum + out

         return pp_sum
@@ -791,8 +789,8 @@ def __init__(
         )

     def forward(self, x_low, x_high):
-        x_low_upsampled = F.upsample(
-            x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear"
+        x_low_upsampled = F.interpolate(
+            x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
         )

         low_cls = self.low_classifier_conv(x_low_upsampled)
@@ -824,16 +822,13 @@ def interp(input, output_size, mode="bilinear"):
     oh, ow = output_size

     # normalize to [-1, 1]
-    h = torch.arange(0, oh) / (oh - 1) * 2 - 1
-    w = torch.arange(0, ow) / (ow - 1) * 2 - 1
+    h = torch.arange(0, oh, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (oh - 1) * 2 - 1
+    w = torch.arange(0, ow, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (ow - 1) * 2 - 1

-    grid = torch.zeros(oh, ow, 2)
+    grid = torch.zeros(oh, ow, 2, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu')
     grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
     grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
     grid = grid.unsqueeze(0).repeat(n, 1, 1, 1)  # grid.shape: [n, oh, ow, 2]
-    grid = Variable(grid)
-    if input.is_cuda:
-        grid = grid.cuda()

     return F.grid_sample(input, grid, mode=mode)
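With Variable removed, the sampling grid in interp is created directly on the right device via the is_cuda ternary. A minimal alternative sketch, not part of the commit (interp_grid is a hypothetical name), that keys off input.device instead and so also covers non-default CUDA devices; torch.linspace(-1, 1, k) yields the same values as the float arange normalisation above:

import torch
import torch.nn.functional as F

def interp_grid(input, output_size, mode="bilinear"):
    n = input.shape[0]
    oh, ow = output_size
    device = input.device                              # CPU or whichever GPU holds the input
    h = torch.linspace(-1.0, 1.0, oh, device=device)   # == arange(0, oh) / (oh - 1) * 2 - 1
    w = torch.linspace(-1.0, 1.0, ow, device=device)
    grid = torch.zeros(oh, ow, 2, device=device)
    grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)       # x coordinates
    grid[:, :, 1] = h.unsqueeze(1).repeat(1, ow)       # y coordinates
    grid = grid.unsqueeze(0).repeat(n, 1, 1, 1)        # (n, oh, ow, 2)
    # On PyTorch >= 1.3, grid_sample also accepts align_corners; passing True keeps
    # -1/1 mapped to the corner pixel centres, matching how this grid is built.
    return F.grid_sample(input, grid, mode=mode)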

2 changes: 1 addition & 1 deletion validate.py
@@ -36,7 +36,7 @@ def validate(cfg, args):
         split=cfg['data']['val_split'],
         is_transform=True,
         img_size=(cfg['data']['img_rows'],
-                  cfg['data']['img_rows']),
+                  cfg['data']['img_cols']),
     )

     n_classes = loader.n_classes
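The validate.py fix matters for non-square datasets: the old call passed img_rows twice, so validation images were resized to a square. Hypothetical config values for illustration:

cfg = {"data": {"img_rows": 513, "img_cols": 1025, "val_split": "val"}}   # made-up values
img_size = (cfg["data"]["img_rows"], cfg["data"]["img_cols"])             # (513, 1025); previously (513, 513)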
