Skip to content

Commit

Permalink
Fix torch bloat16 -> numpy float32 conversion for compile max-autotune (
Browse files Browse the repository at this point in the history
  • Loading branch information
mawanda-jun authored Dec 7, 2023
1 parent 97a69cf commit ceaae4b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions segment_anything_fast/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def predict(
)

masks_np = masks[0].detach().cpu().numpy()
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
iou_predictions_np = iou_predictions[0].detach().cpu().float().numpy()
low_res_masks_np = low_res_masks[0].detach().cpu().float().numpy()
return masks_np, iou_predictions_np, low_res_masks_np

@torch.no_grad()
Expand Down

0 comments on commit ceaae4b

Please sign in to comment.