From 2e0908db2b2a6be3f33fdbc61d764a2396d15b71 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Fri, 4 Aug 2023 15:46:40 +0100 Subject: [PATCH] Add RegexFullMatch operator (#5401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This PR introduces the `RegexFullMatch` operator, as originally proposed in https://github.com/onnx/onnx/issues/5317. The `RegexFullMatch` operator takes one string tensor as input and returns a bool tensor of identical shape indicating if each element fully matches the regex pattern encoded in the `pattern` string attribute. This attribute is a string and we expect valid [re2](https://github.com/google/re2) regex. Some examples are as follows: ``` RegexFullMatch(["www.google.com", "www.facebook.com", "www.bbc.co.uk"], "www\.[\w.-]+\.\bcom\b") => [True, True, False] RegexFullMatch([["account@gmail.com", "account@hotmail.com"], ["not email", "account2@yahoo.com"]], "(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)") => [[True, False], [False, True]] ``` ### Motivation and Context Closes https://github.com/onnx/onnx/issues/5317. Following discussion at the last Operators SIG Weekly the "engine" attribute has been dropped in favour of simply using [re2](https://github.com/google/re2) syntax for now. This reflects the fact that both [Tensorflow](https://www.tensorflow.org/api_docs/python/tf/strings/regex_full_match) and [PyTorch](https://pytorch.org/text/0.15.0/transforms.html#regextokenizer) operators requiring regex use re2 already. --------- Signed-off-by: Aditya Goel Signed-off-by: Chun-Wei Chen Signed-off-by: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Co-authored-by: Chun-Wei Chen Co-authored-by: Christian Bourjau Co-authored-by: Xavier Dupré --- .azure-pipelines/Linux-CI.yml | 2 +- .azure-pipelines/MacOS-CI.yml | 2 +- .azure-pipelines/Windows-CI.yml | 2 +- .github/workflows/release_win.yml | 3 + README.md | 2 +- docs/Changelog.md | 38 ++++++ docs/Operators.md | 116 ++++++++++++++++++ docs/TestCoverage.md | 74 ++++++++++- .../test/case/node/regex_full_match.py | 67 ++++++++++ .../test_regex_full_match_basic/model.onnx | Bin 0 -> 148 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | Bin 0 -> 12 bytes .../model.onnx | Bin 0 -> 187 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | Bin 0 -> 15 bytes .../test_regex_full_match_empty/model.onnx | Bin 0 -> 180 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 9 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 11 bytes onnx/defs/operator_sets.h | 2 + onnx/defs/text/defs.cc | 28 +++++ onnx/reference/ops/_op_list.py | 1 + onnx/reference/ops/op_regex_full_match.py | 34 +++++ onnx/test/automatic_upgrade_test.py | 10 ++ onnx/test/reference_evaluator_backend_test.py | 7 ++ onnx/test/reference_evaluator_test.py | 54 ++++++++ onnx/test/shape_inference_test.py | 26 ++++ onnx/test/test_backend_onnxruntime.py | 1 + onnx/test/test_backend_reference.py | 7 ++ requirements-dev.txt | 2 + requirements-reference.txt | 1 + requirements-release.txt | 1 + setup.py | 6 + 32 files changed, 483 insertions(+), 5 deletions(-) create mode 100644 onnx/backend/test/case/node/regex_full_match.py create mode 100644 onnx/backend/test/data/node/test_regex_full_match_basic/model.onnx create mode 100644 onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_regex_full_match_email_domain/model.onnx create mode 100644 onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_regex_full_match_empty/model.onnx create mode 100644 onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/output_0.pb create mode 100644 onnx/reference/ops/op_regex_full_match.py create mode 100644 requirements-reference.txt diff --git a/.azure-pipelines/Linux-CI.yml b/.azure-pipelines/Linux-CI.yml index 80e5be406fa..661daf3032c 100644 --- a/.azure-pipelines/Linux-CI.yml +++ b/.azure-pipelines/Linux-CI.yml @@ -68,7 +68,7 @@ jobs: export CMAKE_ARGS="-DONNX_WERROR=ON -DONNX_USE_PROTOBUF_SHARED_LIBS=ON" # enable more sanitizer export CMAKE_ARGS="${CMAKE_ARGS} -DCMAKE_CXX_FLAGS='-fsanitize=undefined -fno-sanitize-recover=all '" - pip install -e . -v + pip install -e ".[reference]" -v displayName: 'Install ONNX and dependencies' - script: | diff --git a/.azure-pipelines/MacOS-CI.yml b/.azure-pipelines/MacOS-CI.yml index d924799dd65..4fb761b9849 100644 --- a/.azure-pipelines/MacOS-CI.yml +++ b/.azure-pipelines/MacOS-CI.yml @@ -62,7 +62,7 @@ jobs: if [ '$(onnx_lite)' == '1' ]; then export CMAKE_ARGS="${CMAKE_ARGS} -DONNX_USE_LITE_PROTO=ON" fi - pip install -e . -v + pip install -e ".[reference]" -v displayName: 'Install dependencies and ONNX' - script: | diff --git a/.azure-pipelines/Windows-CI.yml b/.azure-pipelines/Windows-CI.yml index 9513d3948c9..44b345ae05b 100644 --- a/.azure-pipelines/Windows-CI.yml +++ b/.azure-pipelines/Windows-CI.yml @@ -64,7 +64,7 @@ jobs: set CMAKE_ARGS=-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DONNX_USE_LITE_PROTO=ON -DONNX_WERROR=OFF ) - pip install -e . -v + pip install -e ".[reference]" -v pytest IF NOT %ERRORLEVEL% EQU 0 ( @echo "pytest failed" diff --git a/.github/workflows/release_win.yml b/.github/workflows/release_win.yml index 1ec67db8ff8..ce25e42d2e2 100644 --- a/.github/workflows/release_win.yml +++ b/.github/workflows/release_win.yml @@ -49,6 +49,9 @@ jobs: run: | python -m pip install -q --upgrade pip cd onnx + if ('${{ matrix.architecture }}' -eq 'x86') { + sed -i '' '/google-re2/d' requirements-release.txt + } python -m pip install -q -r requirements-release.txt - name: Build ONNX wheel diff --git a/README.md b/README.md index 4af36235cde..08a46e70c6b 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ A roadmap process takes place every year. More details can be found [here](https ONNX released packages are published in PyPi. ```sh -pip install onnx +pip install onnx # or pip install onnx[reference] for optional reference implementation dependencies ``` [ONNX weekly packages](https://pypi.org/project/onnx-weekly/) are published in PyPI to enable experimentation and early testing. diff --git a/docs/Changelog.md b/docs/Changelog.md index aec36b1477a..9392f592666 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -24092,6 +24092,44 @@ This version of the operator has been available since version 20 of the default
Constrain grid types to float tensors.
+### **RegexFullMatch-20** + + RegexFullMatch performs a full regex match on each element of the input tensor. If an element fully matches the regex pattern specified as an attribute, the corresponding element in the output is True and it is False otherwise. [RE2](https://github.com/google/re2/wiki/Syntax) regex syntax is used. + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +
+
pattern : string
+
Regex pattern to match on. This must be valid RE2 syntax.
+
+ +#### Inputs + +
+
X (non-differentiable) : T1
+
Tensor with strings to match on.
+
+ +#### Outputs + +
+
Y (non-differentiable) : T2
+
Tensor of bools indicating if each input string fully matches the regex pattern specified.
+
+ +#### Type Constraints + +
+
T1 : tensor(string)
+
Inputs must be UTF-8 strings
+
T2 : tensor(bool)
+
Outputs are bools and are True where there is a full regex match and False otherwise.
+
+ ### **StringConcat-20** StringConcat concatenates string tensors elementwise (with NumPy-style broadcasting support) diff --git a/docs/Operators.md b/docs/Operators.md index 7cff608a3c1..56054753755 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -118,6 +118,7 @@ For an operator input/output's differentiability, it can be differentiable, |ReduceMin|18, 13, 12, 11, 1| |ReduceProd|18, 13, 11, 1| |ReduceSum|13, 11, 1| +|RegexFullMatch|20| |Reshape|19, 14, 13, 5, 1| |Resize|19, 18, 13, 11, 10| |ReverseSequence|10| @@ -22588,6 +22589,121 @@ expect( +### **RegexFullMatch** + + RegexFullMatch performs a full regex match on each element of the input tensor. If an element fully matches the regex pattern specified as an attribute, the corresponding element in the output is True and it is False otherwise. [RE2](https://github.com/google/re2/wiki/Syntax) regex syntax is used. + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +
+
pattern : string
+
Regex pattern to match on. This must be valid RE2 syntax.
+
+ +#### Inputs + +
+
X (non-differentiable) : T1
+
Tensor with strings to match on.
+
+ +#### Outputs + +
+
Y (non-differentiable) : T2
+
Tensor of bools indicating if each input string fully matches the regex pattern specified.
+
+ +#### Type Constraints + +
+
T1 : tensor(string)
+
Inputs must be UTF-8 strings
+
T2 : tensor(bool)
+
Outputs are bools and are True where there is a full regex match and False otherwise.
+
+ + +#### Examples + +
+basic + +```python +node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"www\.[\w.-]+\.\bcom\b", +) + +x = np.array(["www.google.com", "www.facebook.com", "www.bbc.co.uk"]).astype( + object +) +result = np.array([True, True, False]) +expect(node, inputs=[x], outputs=[result], name="test_regex_full_match_basic") +``` + +
+ + +
+match_email_domain + +```python +node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)", +) + +x = np.array( + [ + ["account@gmail.com", "account@hotmail.com"], + ["not email", "account2@yahoo.com"], + ] +).astype(object) +result = np.array([[True, False], [False, True]]) +expect( + node, + inputs=[x], + outputs=[result], + name="test_regex_full_match_email_domain", +) +``` + +
+ + +
+match_empty + +```python +node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)", +) + +x = np.array([[], []]).astype(object) +result = np.array([[], []]).astype(bool) +expect( + node, + inputs=[x], + outputs=[result], + name="test_regex_full_match_empty", +) +``` + +
+ + ### **Relu** Relu takes one input data (Tensor) and produces one output data diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index 3ab6b75f1c1..3e2a2d8e77b 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6,7 +6,7 @@ * [Overall Test Coverage](#overall-test-coverage) # Node Test Coverage ## Summary -Node tests have covered 177/190 (93.16%, 5 generators excluded) common operators. +Node tests have covered 178/191 (93.19%, 5 generators excluded) common operators. Node tests have covered 0/0 (N/A) experimental operators. @@ -15258,6 +15258,78 @@ expect( +### RegexFullMatch +There are 3 test cases, listed as following: +
+basic + +```python +node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"www\.[\w.-]+\.\bcom\b", +) + +x = np.array(["www.google.com", "www.facebook.com", "www.bbc.co.uk"]).astype( + object +) +result = np.array([True, True, False]) +expect(node, inputs=[x], outputs=[result], name="test_regex_full_match_basic") +``` + +
+
+match_email_domain + +```python +node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)", +) + +x = np.array( + [ + ["account@gmail.com", "account@hotmail.com"], + ["not email", "account2@yahoo.com"], + ] +).astype(object) +result = np.array([[True, False], [False, True]]) +expect( + node, + inputs=[x], + outputs=[result], + name="test_regex_full_match_email_domain", +) +``` + +
+
+match_empty + +```python +node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)", +) + +x = np.array([[], []]).astype(object) +result = np.array([[], []]).astype(bool) +expect( + node, + inputs=[x], + outputs=[result], + name="test_regex_full_match_empty", +) +``` + +
+ + ### Relu There are 1 test cases, listed as following:
diff --git a/onnx/backend/test/case/node/regex_full_match.py b/onnx/backend/test/case/node/regex_full_match.py new file mode 100644 index 00000000000..5fe70dba13c --- /dev/null +++ b/onnx/backend/test/case/node/regex_full_match.py @@ -0,0 +1,67 @@ +# Copyright (c) ONNX Project Contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import onnx +from onnx.backend.test.case.base import Base +from onnx.backend.test.case.node import expect + + +class RegexFullMatch(Base): + @staticmethod + def export_basic() -> None: + node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"www\.[\w.-]+\.\bcom\b", + ) + + x = np.array(["www.google.com", "www.facebook.com", "www.bbc.co.uk"]).astype( + object + ) + result = np.array([True, True, False]) + expect(node, inputs=[x], outputs=[result], name="test_regex_full_match_basic") + + @staticmethod + def export_match_email_domain() -> None: + node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)", + ) + + x = np.array( + [ + ["account@gmail.com", "account@hotmail.com"], + ["not email", "account2@yahoo.com"], + ] + ).astype(object) + result = np.array([[True, False], [False, True]]) + expect( + node, + inputs=[x], + outputs=[result], + name="test_regex_full_match_email_domain", + ) + + @staticmethod + def export_match_empty() -> None: + node = onnx.helper.make_node( + "RegexFullMatch", + inputs=["X"], + outputs=["Y"], + pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)", + ) + + x = np.array([[], []]).astype(object) + result = np.array([[], []]).astype(bool) + expect( + node, + inputs=[x], + outputs=[result], + name="test_regex_full_match_empty", + ) diff --git a/onnx/backend/test/data/node/test_regex_full_match_basic/model.onnx b/onnx/backend/test/data/node/test_regex_full_match_basic/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..539f0ef722c29f0dcc4d029ae679d93fef7ec431 GIT binary patch literal 148 zcmdy&{F1NFGwsY zNiE7#5-l$;kI{>cDc94D)sE4NNlMPojY(R-$SfocG9V&P&C5CH(E>Lv&P literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..8b647d880a7 --- /dev/null +++ b/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +2www.google.com2www.facebook.com2 www.bbc.co.ukBX \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..b676b4374d1c0ef9b19b88dde3ff07172f6bd786 GIT binary patch literal 12 Tcmd;J7T|PZjPzn=WMlvU2QC2+ literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_regex_full_match_email_domain/model.onnx b/onnx/backend/test/data/node/test_regex_full_match_email_domain/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..bf03e4a2dee5b2a1d20e691afcd5dd6b94c9e9de GIT binary patch literal 187 zcmdPM_EXauvbYkIR H5D)()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); } }; diff --git a/onnx/defs/text/defs.cc b/onnx/defs/text/defs.cc index 01fc70b3237..5ff70cec265 100644 --- a/onnx/defs/text/defs.cc +++ b/onnx/defs/text/defs.cc @@ -33,6 +33,34 @@ ONNX_OPERATOR_SET_SCHEMA( *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()); })); +static const char* RegexFullMatch_doc = + R"DOC(RegexFullMatch performs a full regex match on each element of the input tensor. If an element fully matches the regex pattern specified as an attribute, the corresponding element in the output is True and it is False otherwise. [RE2](https://github.com/google/re2/wiki/Syntax) regex syntax is used.)DOC"; +ONNX_OPERATOR_SET_SCHEMA( + RegexFullMatch, + 20, + OpSchema() + .Input(0, "X", "Tensor with strings to match on.", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) + .Attr("pattern", "Regex pattern to match on. This must be valid RE2 syntax.", AttributeProto::STRING, false) + .Output( + 0, + "Y", + "Tensor of bools indicating if each input string fully matches the regex pattern specified.", + "T2", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .TypeConstraint("T1", {"tensor(string)"}, "Inputs must be UTF-8 strings") + .TypeConstraint( + "T2", + {"tensor(bool)"}, + "Outputs are bools and are True where there is a full regex match and False otherwise.") + .SetDoc(RegexFullMatch_doc) + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + updateOutputElemType(ctx, 0, TensorProto::BOOL); + propagateShapeFromInputToOutput(ctx, 0, 0); + })); + static const char* StringSplit_doc = R"DOC(StringSplit splits a string tensor's elements into substrings based on a delimiter attribute and a maxsplit attribute. diff --git a/onnx/reference/ops/_op_list.py b/onnx/reference/ops/_op_list.py index be0925741a8..5caabeed7ce 100644 --- a/onnx/reference/ops/_op_list.py +++ b/onnx/reference/ops/_op_list.py @@ -176,6 +176,7 @@ ReduceSumSquare_1, ReduceSumSquare_18, ) +from onnx.reference.ops.op_regex_full_match import RegexFullMatch from onnx.reference.ops.op_relu import Relu from onnx.reference.ops.op_reshape import Reshape_5, Reshape_14 from onnx.reference.ops.op_resize import Resize diff --git a/onnx/reference/ops/op_regex_full_match.py b/onnx/reference/ops/op_regex_full_match.py new file mode 100644 index 00000000000..c3d65849113 --- /dev/null +++ b/onnx/reference/ops/op_regex_full_match.py @@ -0,0 +1,34 @@ +# Copyright (c) ONNX Project Contributors + +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=R0912,R0913,W0221 + +import numpy as np + +from onnx.reference.op_run import OpRun + +_acceptable_str_dtypes = ("U", "O") + + +class RegexFullMatch(OpRun): + def _run(self, x, pattern=None): + try: + # pylint: disable=import-outside-toplevel` + import re2 + except ImportError as e: + raise ImportError( + "re2 must be installed to use the reference implementation of the RegexFullMatch operator" + ) from e + + # As per onnx/mapping.py, object numpy dtype corresponds to TensorProto.STRING + if x.dtype.kind not in _acceptable_str_dtypes: + raise TypeError(f"Input must be string tensor, received dtype {x.dtype}") + try: + regex = re2.compile(pattern) + except re2.error as e: + raise ValueError(f"Invalid regex pattern {pattern!r}") from e + + fullmatch_func = np.vectorize( + lambda x: regex.fullmatch(x) is not None, otypes=[np.bool_] + ) + return (fullmatch_func(x),) diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py index 4b820ab834a..8fa6dea19c5 100644 --- a/onnx/test/automatic_upgrade_test.py +++ b/onnx/test/automatic_upgrade_test.py @@ -1805,6 +1805,16 @@ def test_StringConcat(self) -> None: [[2, 3]], ) + def test_RegexFullMatch(self) -> None: + self._test_op_upgrade( + "RegexFullMatch", + 20, + [[2, 3]], + [[2, 3]], + [TensorProto.STRING], + [TensorProto.BOOL], + ) + def test_ops_tested(self) -> None: all_schemas = onnx.defs.get_all_schemas() all_op_names = [schema.name for schema in all_schemas if schema.domain == ""] diff --git a/onnx/test/reference_evaluator_backend_test.py b/onnx/test/reference_evaluator_backend_test.py index 3b4cd63d4d2..65336296cc8 100644 --- a/onnx/test/reference_evaluator_backend_test.py +++ b/onnx/test/reference_evaluator_backend_test.py @@ -22,6 +22,7 @@ import os import pprint +import sys import unittest try: @@ -82,6 +83,12 @@ "test_castlike_FLOAT_to_BFLOAT16", "test_castlike_FLOAT_to_BFLOAT16_expanded", } +if sys.platform == "win32": + SKIP_TESTS |= { + "test_regex_full_match_basic", + "test_regex_full_match_email_domain", + "test_regex_full_match_empty", + } def assert_allclose_string(expected, value): diff --git a/onnx/test/reference_evaluator_test.py b/onnx/test/reference_evaluator_test.py index b659b5e0409..6b56d9d3d7c 100644 --- a/onnx/test/reference_evaluator_test.py +++ b/onnx/test/reference_evaluator_test.py @@ -13,6 +13,7 @@ import itertools import math +import sys import unittest from contextlib import redirect_stdout from functools import wraps @@ -4949,6 +4950,59 @@ def test_string_split( num_splits, np.array(expected_num_splits, dtype=np.int64) ) + @parameterized.parameterized.expand( + [ + ( + ["www.google.com", "www.facebook.com", "www.bbc.co.uk"], + r"www\.[\w.-]+\.\bcom\b", + [True, True, False], + (3,), + ), + ( + [["Onnx", "tensorflow", "Numpy"], ["Pytorch", "Cython", "numba"]], + r"^[A-Z][a-z]*$", + [[True, False, True], [True, True, False]], + (2, 3), + ), + ( + [ + "account@gmail.com", + "account@hotmail.com", + "not email", + "account2@yahoo.com", + ], + r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)", + [True, False, False, True], + (4,), + ), + ] + ) + @unittest.skipIf( + sys.platform == "win32", "google-re2 package is not built for win32" + ) + def test_regex_full_match(self, x, pattern, expected, expected_shape): + X = make_tensor_value_info("X", TensorProto.STRING, None) + Y = make_tensor_value_info("Y", TensorProto.BOOL, None) + node = make_node("RegexFullMatch", inputs=["X"], outputs=["Y"], pattern=pattern) + model = make_model(make_graph([node], "g", [X], [Y])) + ref = ReferenceEvaluator(model) + result, *_ = ref.run(None, {"X": np.array(x)}) + np.testing.assert_array_equal(result, expected) + self.assertEqual(result.dtype.kind, "b") + self.assertEqual(result.shape, expected_shape) + + @unittest.skipIf( + sys.platform == "win32", "google-re2 package is not built for win32" + ) + def test_regex_invalid_pattern(self): + X = make_tensor_value_info("X", TensorProto.STRING, None) + Y = make_tensor_value_info("Y", TensorProto.BOOL, None) + node = make_node("RegexFullMatch", inputs=["X"], outputs=["Y"], pattern="x)") + model = make_model(make_graph([node], "g", [X], [Y])) + ref = ReferenceEvaluator(model) + with self.assertRaises(ValueError): + ref.run(None, {"X": np.array(["x"])}) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index 12090283dc6..3cae5e97957 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -1622,6 +1622,32 @@ def test_stringconcat_broadcasting(self, _, version) -> None: opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)], ) + @parameterized.expand(all_versions_for("RegexFullMatch")) + def test_regex_full_match(self, _, version) -> None: + graph = self._make_graph( + [("x", TensorProto.STRING, (2, 4, 3, 9))], + [make_node("RegexFullMatch", ["x"], ["y"], pattern=r"^[A-Z][a-z]*$")], + [], + ) + self._assert_inferred( + graph, + [make_tensor_value_info("y", TensorProto.BOOL, (2, 4, 3, 9))], + opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)], + ) + + @parameterized.expand(all_versions_for("RegexFullMatch")) + def test_regex_full_match_empty_shape(self, _, version) -> None: + graph = self._make_graph( + [("x", TensorProto.STRING, ())], + [make_node("RegexFullMatch", ["x"], ["y"], pattern=r"^[A-Z][a-z]*$")], + [], + ) + self._assert_inferred( + graph, + [make_tensor_value_info("y", TensorProto.BOOL, ())], + opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)], + ) + def test_unsqueeze_regular(self) -> None: graph = self._make_graph( [("x", TensorProto.FLOAT, (3, 2)), ("axes", TensorProto.INT64, (4,))], diff --git a/onnx/test/test_backend_onnxruntime.py b/onnx/test/test_backend_onnxruntime.py index de0b08fe802..974e7e0c378 100644 --- a/onnx/test/test_backend_onnxruntime.py +++ b/onnx/test/test_backend_onnxruntime.py @@ -252,6 +252,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): "|equal" "|identity" "|reshape" + "|regex_full_match" "|string_split" "|string_concat" "|gelu" diff --git a/onnx/test/test_backend_reference.py b/onnx/test/test_backend_reference.py index 0beafa9e59f..3627060f4bd 100644 --- a/onnx/test/test_backend_reference.py +++ b/onnx/test/test_backend_reference.py @@ -4,6 +4,7 @@ import os import platform +import sys import unittest from typing import Any @@ -172,6 +173,12 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): # The following tests fail due to discrepancies (small but still higher than 1e-7). backend_test.exclude("test_adam_multiple") # 1e-2 +# Currently google-re2 is not supported on Win32 and is required for the reference implementation of RegexFullMatch. +if sys.platform == "win32": + backend_test.exclude("test_regex_full_match_basic_cpu") + backend_test.exclude("test_regex_full_match_email_domain_cpu") + backend_test.exclude("test_regex_full_match_empty_cpu") + # import all test cases at global scope to make them visible to python.unittest globals().update(backend_test.test_cases) diff --git a/requirements-dev.txt b/requirements-dev.txt index 05cbbec5bfa..66aa8f93a9d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,4 +11,6 @@ twine # Dependencies for linting. Versions match those in setup.py. lintrunner>=0.10.7 lintrunner-adapters>=0.8 +# Dependencies for the reference implementation. +google-re2 # Edit additional linter dependencies in requirements-lintrunner.txt diff --git a/requirements-reference.txt b/requirements-reference.txt new file mode 100644 index 00000000000..2347a47f5cb --- /dev/null +++ b/requirements-reference.txt @@ -0,0 +1 @@ +google-re2 diff --git a/requirements-release.txt b/requirements-release.txt index 96c766cfef4..fee00ee032e 100644 --- a/requirements-release.txt +++ b/requirements-release.txt @@ -7,3 +7,4 @@ wheel setuptools twine parameterized +google-re2 diff --git a/setup.py b/setup.py index ad5981d1928..4ff8d769100 100644 --- a/setup.py +++ b/setup.py @@ -355,6 +355,12 @@ def build_extensions(self): "lintrunner-adapters>=0.3", ] +if not os.path.exists("requirements-reference.txt"): + raise FileNotFoundError("Unable to find requirements-reference.txt") + +with open("requirements-reference.txt") as f: + extras_require["reference"] = f.read().splitlines() + ################################################################################ # Final ################################################################################