diff --git a/official/cv/segment-anything/configs/coco_box_finetune.yaml b/official/cv/segment-anything/configs/coco_box_finetune.yaml index 47bca93c..e5108065 100644 --- a/official/cv/segment-anything/configs/coco_box_finetune.yaml +++ b/official/cv/segment-anything/configs/coco_box_finetune.yaml @@ -6,7 +6,7 @@ mode: 0 # 0: graph, 1: pynative jit_level: O1 # O0 or O1 work_root: &work_root ./work_dir/ log_level: info -amp_level: O2 +amp_level: auto # --------------------------------------------- # Part2: module setting diff --git a/official/cv/segment-anything/configs/flare_box_finetune.yaml b/official/cv/segment-anything/configs/flare_box_finetune.yaml index c5fb8f41..81d8e5dd 100644 --- a/official/cv/segment-anything/configs/flare_box_finetune.yaml +++ b/official/cv/segment-anything/configs/flare_box_finetune.yaml @@ -6,7 +6,7 @@ mode: 0 # 0: graph, 1: pynative jit_level: O1 # O0 or O1 work_root: &work_root ./work_dir/ log_level: info -amp_level: O2 +amp_level: auto # --------------------------------------------- # Part2: module setting diff --git a/official/cv/segment-anything/segment_anything/build_sam.py b/official/cv/segment-anything/segment_anything/build_sam.py index 5210250d..106ce736 100644 --- a/official/cv/segment-anything/segment_anything/build_sam.py +++ b/official/cv/segment-anything/segment_anything/build_sam.py @@ -85,7 +85,7 @@ def _build_sam( embed_dim=encoder_embed_dim, img_size=image_size, mlp_ratio=4, - norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), # use approximate=False to be close to pytorch, ref: # https://www.mindspore.cn/docs/zh-CN/master/note/api_mapping/pytorch_diff/GELU.html?highlight=gelu act_layer=partial(GELU, approximate=False), diff --git a/official/cv/segment-anything/segment_anything/modeling/image_encoder.py b/official/cv/segment-anything/segment_anything/modeling/image_encoder.py index 7ca920e9..84faf6e0 100644 --- a/official/cv/segment-anything/segment_anything/modeling/image_encoder.py +++ b/official/cv/segment-anything/segment_anything/modeling/image_encoder.py @@ -18,7 +18,7 @@ def __init__( mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, - norm_layer: Type[nn.Cell] = mint.nn.LayerNorm, + norm_layer: Type[nn.Cell] = nn.LayerNorm, act_layer: Type[nn.Cell] = GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, @@ -225,7 +225,7 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) - attn = mint.softmax(attn, dim=-1) + attn = mint.nn.functional.softmax(attn, dim=-1) x = mint.bmm(attn, v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) @@ -300,7 +300,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: ms.Tensor) -> ms.Tensor: # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. - rel_pos_resized = mint.interpolate( + rel_pos_resized = mint.nn.functional.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", diff --git a/official/cv/segment-anything/segment_anything/modeling/loss.py b/official/cv/segment-anything/segment_anything/modeling/loss.py index 4af3565f..af53578c 100644 --- a/official/cv/segment-anything/segment_anything/modeling/loss.py +++ b/official/cv/segment-anything/segment_anything/modeling/loss.py @@ -29,7 +29,7 @@ def __init__(self, focal_factor=20.0, dice_factor=1.0, mse_factor=1.0, mask_thre self.focal_loss = FocalLoss(reduction='none') self.dice_loss = DiceLoss(reduction='none') - self.mse_loss = nn.MSELoss(reduction='none') + self.mse_loss = mint.nn.MSELoss(reduction='none') def construct(self, pred_mask, pred_iou, gt_mask, valid_boxes): """ diff --git a/official/cv/segment-anything/segment_anything/modeling/prompt_encoder.py b/official/cv/segment-anything/segment_anything/modeling/prompt_encoder.py index 9fa161a8..813791f3 100644 --- a/official/cv/segment-anything/segment_anything/modeling/prompt_encoder.py +++ b/official/cv/segment-anything/segment_anything/modeling/prompt_encoder.py @@ -218,6 +218,7 @@ def _pe_encoding(self, coords: ms.Tensor) -> ms.Tensor: coords = mint.matmul(coords, self.positional_encoding_gaussian_matrix.astype(dtype)) coords = 2 * np.pi * coords # outputs d_1 x ... x d_n x C shape + coords = coords.to(ms.float32) return mint.cat([mint.sin(coords), mint.cos(coords)], dim=-1) def construct(self, size: Tuple[int, int]) -> ms.Tensor: diff --git a/official/cv/segment-anything/segment_anything/modeling/sam.py b/official/cv/segment-anything/segment_anything/modeling/sam.py index 9faf499c..fc560fa4 100644 --- a/official/cv/segment-anything/segment_anything/modeling/sam.py +++ b/official/cv/segment-anything/segment_anything/modeling/sam.py @@ -149,7 +149,7 @@ def construct( ) # low_res_masks (n, 4, h, w) if multimask_output else (n, 1, h, w) # iou_predictions (n, 4) if multimask_output else (n, 1) - pred_mask = mint.interpolate(low_res_masks, (h, w), mode='bilinear', align_corners=False) + pred_mask = mint.nn.functional.interpolate(low_res_masks, (h, w), mode='bilinear', align_corners=False) pred_masks.append(pred_mask) pred_ious.append(iou_predictions) @@ -185,14 +185,14 @@ def postprocess_masks( (ms.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size. """ - masks = mint.interpolate( + masks = mint.nn.functional.interpolate( masks, (self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear", align_corners=False, ) masks = masks[..., : input_size[0], : input_size[1]] - masks = mint.interpolate(masks, original_size, mode="bilinear", align_corners=False) + masks = mint.nn.functional.interpolate(masks, original_size, mode="bilinear", align_corners=False) return masks def preprocess(self, x: ms.Tensor) -> ms.Tensor: diff --git a/official/cv/segment-anything/segment_anything/modeling/transformer.py b/official/cv/segment-anything/segment_anything/modeling/transformer.py index b7840867..75719a3e 100644 --- a/official/cv/segment-anything/segment_anything/modeling/transformer.py +++ b/official/cv/segment-anything/segment_anything/modeling/transformer.py @@ -50,7 +50,7 @@ def __init__( self.final_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) - self.norm_final_attn = mint.nn.LayerNorm([embedding_dim]) + self.norm_final_attn = nn.LayerNorm([embedding_dim]) def construct( self, @@ -124,17 +124,17 @@ def __init__( """ super().__init__() self.self_attn = Attention(embedding_dim, num_heads) - self.norm1 = mint.nn.LayerNorm([embedding_dim]) + self.norm1 = nn.LayerNorm([embedding_dim]) self.cross_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) - self.norm2 = mint.nn.LayerNorm([embedding_dim]) + self.norm2 = nn.LayerNorm([embedding_dim]) self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) - self.norm3 = mint.nn.LayerNorm([embedding_dim]) + self.norm3 = nn.LayerNorm([embedding_dim]) - self.norm4 = mint.nn.LayerNorm([embedding_dim]) + self.norm4 = nn.LayerNorm([embedding_dim]) self.cross_attn_image_to_token = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) @@ -224,7 +224,7 @@ def construct(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: dtype = q.dtype attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens attn = attn / Tensor(math.sqrt(c_per_head), dtype) - attn = mint.softmax(attn, dim=-1) + attn = mint.nn.functional.softmax(attn, dim=-1) # Get output out = attn @ v diff --git a/official/cv/segment-anything/segment_anything/utils/callbacks.py b/official/cv/segment-anything/segment_anything/utils/callbacks.py index 71fbe44f..089fc8fd 100644 --- a/official/cv/segment-anything/segment_anything/utils/callbacks.py +++ b/official/cv/segment-anything/segment_anything/utils/callbacks.py @@ -130,7 +130,7 @@ def on_train_step_end(self, run_context: RunContext): self.accumulate_loss += loss if cur_step % self.log_interval == 0: - lr = cb_params.network.optimizer.learning_rate.learning_rate[cur_step] + lr = cb_params.network.optimizer.learning_rate.learning_rate[cur_step-1] smooth_loss = self.accumulate_loss / self.log_interval step_cost = time.time() - self.step_start_time diff --git a/official/cv/segment-anything/segment_anything/utils/transforms.py b/official/cv/segment-anything/segment_anything/utils/transforms.py index f34b7254..9122a56c 100644 --- a/official/cv/segment-anything/segment_anything/utils/transforms.py +++ b/official/cv/segment-anything/segment_anything/utils/transforms.py @@ -101,7 +101,7 @@ def apply_image_ms(self, image: ms.Tensor) -> ms.Tensor: target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) # TODO note original version has antialias=True, ref: # https://stackoverflow.com/questions/60949936/why-bilinear-scaling-of-images-with-pil-and-pytorch-produces-different-results - return mint.interpolate( + return mint.nn.functional.interpolate( image, target_size, mode="bilinear", align_corners=False ) diff --git a/official/cv/segment-anything/test/test_blip2_generate.py b/official/cv/segment-anything/test/test_blip2_generate.py index 785936a8..fb9ba207 100644 --- a/official/cv/segment-anything/test/test_blip2_generate.py +++ b/official/cv/segment-anything/test/test_blip2_generate.py @@ -31,7 +31,7 @@ logits_per_image = mint.matmul(image_features, text_features.T) / model.temp -probs = mint.softmax(logits_per_image, dim=-1).asnumpy() +probs = mint.nn.functional.softmax(logits_per_image, dim=-1).asnumpy() print('logits', logits_per_image) print('prob', probs) diff --git a/official/cv/segment-anything/test/test_blip2_image_in_sam.py b/official/cv/segment-anything/test/test_blip2_image_in_sam.py index ebcf2256..e9e36d69 100644 --- a/official/cv/segment-anything/test/test_blip2_image_in_sam.py +++ b/official/cv/segment-anything/test/test_blip2_image_in_sam.py @@ -37,7 +37,7 @@ logits_per_image = mint.matmul(image_features, text_features.T) / model.temp # (20, 5) -probs = mint.softmax(logits_per_image, dim=-1).asnumpy() # (20, 5) +probs = mint.nn.functional.softmax(logits_per_image, dim=-1).asnumpy() # (20, 5) for i in range(20): print(f'\n\n{i}') diff --git a/official/cv/segment-anything/test/test_clip_generate.py b/official/cv/segment-anything/test/test_clip_generate.py index f4aa616e..69b48f91 100644 --- a/official/cv/segment-anything/test/test_clip_generate.py +++ b/official/cv/segment-anything/test/test_clip_generate.py @@ -23,7 +23,7 @@ input_images = input_images.astype(mindspore.float32) logits_per_image, _ = model(input_images, input_ids) -probs = mint.softmax(logits_per_image, dim=-1).asnumpy() +probs = mint.nn.functional.softmax(logits_per_image, dim=-1).asnumpy() print('logits', logits_per_image) print('prob', probs) diff --git a/official/cv/segment-anything/train.py b/official/cv/segment-anything/train.py index 52b36ae6..5dae5865 100644 --- a/official/cv/segment-anything/train.py +++ b/official/cv/segment-anything/train.py @@ -34,6 +34,7 @@ def main(args) -> None: loss_fn = create_loss_fn(args.network.loss) network.set_train() network = amp.auto_mixed_precision(network, args.get('amp_level', 'O0')) + loss_fn = amp.auto_mixed_precision(loss_fn, args.get('amp_level', 'O0')) # Step3: create optimizer, including learning rate scheduler and group parameter settings optimizer = create_optimizer(params=network.trainable_params(), args=args.optimizer,