Skip to content

Commit

Permalink
adding support to isinf, isnan and nonzero (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
karalinkw authored Mar 29, 2024
1 parent f04c3a7 commit 86e6117
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 0 deletions.
3 changes: 3 additions & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from onnx2torch.node_converters.global_average_pool import *
from onnx2torch.node_converters.identity import *
from onnx2torch.node_converters.instance_norm import *
from onnx2torch.node_converters.isinf import *
from onnx2torch.node_converters.isnan import *
from onnx2torch.node_converters.layer_norm import *
from onnx2torch.node_converters.logical import *
from onnx2torch.node_converters.lrn import *
Expand All @@ -32,6 +34,7 @@
from onnx2torch.node_converters.mod import *
from onnx2torch.node_converters.neg import *
from onnx2torch.node_converters.nms import *
from onnx2torch.node_converters.nonzero import *
from onnx2torch.node_converters.pad import *
from onnx2torch.node_converters.pow import *
from onnx2torch.node_converters.range import *
Expand Down
34 changes: 34 additions & 0 deletions onnx2torch/node_converters/isinf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
__all__ = [
'OnnxIsInf',
]

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult


class OnnxIsInf(nn.Module, OnnxToTorchModule):
def __init__(self):
super().__init__()

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return torch.isinf(input_tensor)


@add_converter(operation_type='IsInf', version=10)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
torch_module = OnnxIsInf()

return OperationConverterResult(
torch_module=torch_module,
onnx_mapping=OnnxMapping(
inputs=(node.input_values[0],),
outputs=node.output_values,
),
)
34 changes: 34 additions & 0 deletions onnx2torch/node_converters/isnan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
__all__ = [
'OnnxIsNaN',
]

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult


class OnnxIsNaN(nn.Module, OnnxToTorchModule):
def __init__(self):
super().__init__()

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return torch.isnan(input_tensor)


@add_converter(operation_type='IsNaN', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
torch_module = OnnxIsNaN()

return OperationConverterResult(
torch_module=torch_module,
onnx_mapping=OnnxMapping(
inputs=(node.input_values[0],),
outputs=node.output_values,
),
)
34 changes: 34 additions & 0 deletions onnx2torch/node_converters/nonzero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
__all__ = [
'OnnxNonZero',
]

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult


class OnnxNonZero(nn.Module, OnnxToTorchModule):
def __init__(self):
super().__init__()

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return torch.nonzero(input_tensor)


@add_converter(operation_type='NonZero', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
torch_module = OnnxNonZero()

return OperationConverterResult(
torch_module=torch_module,
onnx_mapping=OnnxMapping(
inputs=(node.input_values[0],),
outputs=node.output_values,
),
)

0 comments on commit 86e6117

Please sign in to comment.