diff --git a/.azure-pipelines/Linux-CI.yml b/.azure-pipelines/Linux-CI.yml
index 80e5be406fa..661daf3032c 100644
--- a/.azure-pipelines/Linux-CI.yml
+++ b/.azure-pipelines/Linux-CI.yml
@@ -68,7 +68,7 @@ jobs:
export CMAKE_ARGS="-DONNX_WERROR=ON -DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
# enable more sanitizer
export CMAKE_ARGS="${CMAKE_ARGS} -DCMAKE_CXX_FLAGS='-fsanitize=undefined -fno-sanitize-recover=all '"
- pip install -e . -v
+ pip install -e ".[reference]" -v
displayName: 'Install ONNX and dependencies'
- script: |
diff --git a/.azure-pipelines/MacOS-CI.yml b/.azure-pipelines/MacOS-CI.yml
index d924799dd65..4fb761b9849 100644
--- a/.azure-pipelines/MacOS-CI.yml
+++ b/.azure-pipelines/MacOS-CI.yml
@@ -62,7 +62,7 @@ jobs:
if [ '$(onnx_lite)' == '1' ]; then
export CMAKE_ARGS="${CMAKE_ARGS} -DONNX_USE_LITE_PROTO=ON"
fi
- pip install -e . -v
+ pip install -e ".[reference]" -v
displayName: 'Install dependencies and ONNX'
- script: |
diff --git a/.azure-pipelines/Windows-CI.yml b/.azure-pipelines/Windows-CI.yml
index 9513d3948c9..44b345ae05b 100644
--- a/.azure-pipelines/Windows-CI.yml
+++ b/.azure-pipelines/Windows-CI.yml
@@ -64,7 +64,7 @@ jobs:
set CMAKE_ARGS=-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DONNX_USE_LITE_PROTO=ON -DONNX_WERROR=OFF
)
- pip install -e . -v
+ pip install -e ".[reference]" -v
pytest
IF NOT %ERRORLEVEL% EQU 0 (
@echo "pytest failed"
diff --git a/.github/workflows/release_win.yml b/.github/workflows/release_win.yml
index 1ec67db8ff8..ce25e42d2e2 100644
--- a/.github/workflows/release_win.yml
+++ b/.github/workflows/release_win.yml
@@ -49,6 +49,9 @@ jobs:
run: |
python -m pip install -q --upgrade pip
cd onnx
+ if ('${{ matrix.architecture }}' -eq 'x86') {
+ sed -i '' '/google-re2/d' requirements-release.txt
+ }
python -m pip install -q -r requirements-release.txt
- name: Build ONNX wheel
diff --git a/README.md b/README.md
index 4af36235cde..08a46e70c6b 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,7 @@ A roadmap process takes place every year. More details can be found [here](https
ONNX released packages are published in PyPi.
```sh
-pip install onnx
+pip install onnx # or pip install onnx[reference] for optional reference implementation dependencies
```
[ONNX weekly packages](https://pypi.org/project/onnx-weekly/) are published in PyPI to enable experimentation and early testing.
diff --git a/docs/Changelog.md b/docs/Changelog.md
index aec36b1477a..9392f592666 100644
--- a/docs/Changelog.md
+++ b/docs/Changelog.md
@@ -24092,6 +24092,44 @@ This version of the operator has been available since version 20 of the default
Constrain grid types to float tensors.
+### **RegexFullMatch-20**
+
+ RegexFullMatch performs a full regex match on each element of the input tensor. If an element fully matches the regex pattern specified as an attribute, the corresponding element in the output is True and it is False otherwise. [RE2](https://github.com/google/re2/wiki/Syntax) regex syntax is used.
+
+#### Version
+
+This version of the operator has been available since version 20 of the default ONNX operator set.
+
+#### Attributes
+
+
+- pattern : string
+- Regex pattern to match on. This must be valid RE2 syntax.
+
+
+#### Inputs
+
+
+- X (non-differentiable) : T1
+- Tensor with strings to match on.
+
+
+#### Outputs
+
+
+- Y (non-differentiable) : T2
+- Tensor of bools indicating if each input string fully matches the regex pattern specified.
+
+
+#### Type Constraints
+
+
+- T1 : tensor(string)
+- Inputs must be UTF-8 strings
+- T2 : tensor(bool)
+- Outputs are bools and are True where there is a full regex match and False otherwise.
+
+
### **StringConcat-20**
StringConcat concatenates string tensors elementwise (with NumPy-style broadcasting support)
diff --git a/docs/Operators.md b/docs/Operators.md
index 7cff608a3c1..56054753755 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -118,6 +118,7 @@ For an operator input/output's differentiability, it can be differentiable,
|ReduceMin|18, 13, 12, 11, 1|
|ReduceProd|18, 13, 11, 1|
|ReduceSum|13, 11, 1|
+|RegexFullMatch|20|
|Reshape|19, 14, 13, 5, 1|
|Resize|19, 18, 13, 11, 10|
|ReverseSequence|10|
@@ -22588,6 +22589,121 @@ expect(
+### **RegexFullMatch**
+
+ RegexFullMatch performs a full regex match on each element of the input tensor. If an element fully matches the regex pattern specified as an attribute, the corresponding element in the output is True and it is False otherwise. [RE2](https://github.com/google/re2/wiki/Syntax) regex syntax is used.
+
+#### Version
+
+This version of the operator has been available since version 20 of the default ONNX operator set.
+
+#### Attributes
+
+
+- pattern : string
+- Regex pattern to match on. This must be valid RE2 syntax.
+
+
+#### Inputs
+
+
+- X (non-differentiable) : T1
+- Tensor with strings to match on.
+
+
+#### Outputs
+
+
+- Y (non-differentiable) : T2
+- Tensor of bools indicating if each input string fully matches the regex pattern specified.
+
+
+#### Type Constraints
+
+
+- T1 : tensor(string)
+- Inputs must be UTF-8 strings
+- T2 : tensor(bool)
+- Outputs are bools and are True where there is a full regex match and False otherwise.
+
+
+
+#### Examples
+
+
+basic
+
+```python
+node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"www\.[\w.-]+\.\bcom\b",
+)
+
+x = np.array(["www.google.com", "www.facebook.com", "www.bbc.co.uk"]).astype(
+ object
+)
+result = np.array([True, True, False])
+expect(node, inputs=[x], outputs=[result], name="test_regex_full_match_basic")
+```
+
+
+
+
+
+match_email_domain
+
+```python
+node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)",
+)
+
+x = np.array(
+ [
+ ["account@gmail.com", "account@hotmail.com"],
+ ["not email", "account2@yahoo.com"],
+ ]
+).astype(object)
+result = np.array([[True, False], [False, True]])
+expect(
+ node,
+ inputs=[x],
+ outputs=[result],
+ name="test_regex_full_match_email_domain",
+)
+```
+
+
+
+
+
+match_empty
+
+```python
+node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)",
+)
+
+x = np.array([[], []]).astype(object)
+result = np.array([[], []]).astype(bool)
+expect(
+ node,
+ inputs=[x],
+ outputs=[result],
+ name="test_regex_full_match_empty",
+)
+```
+
+
+
+
### **Relu**
Relu takes one input data (Tensor) and produces one output data
diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md
index 3ab6b75f1c1..3e2a2d8e77b 100644
--- a/docs/TestCoverage.md
+++ b/docs/TestCoverage.md
@@ -6,7 +6,7 @@
* [Overall Test Coverage](#overall-test-coverage)
# Node Test Coverage
## Summary
-Node tests have covered 177/190 (93.16%, 5 generators excluded) common operators.
+Node tests have covered 178/191 (93.19%, 5 generators excluded) common operators.
Node tests have covered 0/0 (N/A) experimental operators.
@@ -15258,6 +15258,78 @@ expect(
+### RegexFullMatch
+There are 3 test cases, listed as following:
+
+basic
+
+```python
+node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"www\.[\w.-]+\.\bcom\b",
+)
+
+x = np.array(["www.google.com", "www.facebook.com", "www.bbc.co.uk"]).astype(
+ object
+)
+result = np.array([True, True, False])
+expect(node, inputs=[x], outputs=[result], name="test_regex_full_match_basic")
+```
+
+
+
+match_email_domain
+
+```python
+node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)",
+)
+
+x = np.array(
+ [
+ ["account@gmail.com", "account@hotmail.com"],
+ ["not email", "account2@yahoo.com"],
+ ]
+).astype(object)
+result = np.array([[True, False], [False, True]])
+expect(
+ node,
+ inputs=[x],
+ outputs=[result],
+ name="test_regex_full_match_email_domain",
+)
+```
+
+
+
+match_empty
+
+```python
+node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)",
+)
+
+x = np.array([[], []]).astype(object)
+result = np.array([[], []]).astype(bool)
+expect(
+ node,
+ inputs=[x],
+ outputs=[result],
+ name="test_regex_full_match_empty",
+)
+```
+
+
+
+
### Relu
There are 1 test cases, listed as following:
diff --git a/onnx/backend/test/case/node/regex_full_match.py b/onnx/backend/test/case/node/regex_full_match.py
new file mode 100644
index 00000000000..5fe70dba13c
--- /dev/null
+++ b/onnx/backend/test/case/node/regex_full_match.py
@@ -0,0 +1,67 @@
+# Copyright (c) ONNX Project Contributors
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+
+import onnx
+from onnx.backend.test.case.base import Base
+from onnx.backend.test.case.node import expect
+
+
+class RegexFullMatch(Base):
+ @staticmethod
+ def export_basic() -> None:
+ node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"www\.[\w.-]+\.\bcom\b",
+ )
+
+ x = np.array(["www.google.com", "www.facebook.com", "www.bbc.co.uk"]).astype(
+ object
+ )
+ result = np.array([True, True, False])
+ expect(node, inputs=[x], outputs=[result], name="test_regex_full_match_basic")
+
+ @staticmethod
+ def export_match_email_domain() -> None:
+ node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)",
+ )
+
+ x = np.array(
+ [
+ ["account@gmail.com", "account@hotmail.com"],
+ ["not email", "account2@yahoo.com"],
+ ]
+ ).astype(object)
+ result = np.array([[True, False], [False, True]])
+ expect(
+ node,
+ inputs=[x],
+ outputs=[result],
+ name="test_regex_full_match_email_domain",
+ )
+
+ @staticmethod
+ def export_match_empty() -> None:
+ node = onnx.helper.make_node(
+ "RegexFullMatch",
+ inputs=["X"],
+ outputs=["Y"],
+ pattern=r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)",
+ )
+
+ x = np.array([[], []]).astype(object)
+ result = np.array([[], []]).astype(bool)
+ expect(
+ node,
+ inputs=[x],
+ outputs=[result],
+ name="test_regex_full_match_empty",
+ )
diff --git a/onnx/backend/test/data/node/test_regex_full_match_basic/model.onnx b/onnx/backend/test/data/node/test_regex_full_match_basic/model.onnx
new file mode 100644
index 00000000000..539f0ef722c
Binary files /dev/null and b/onnx/backend/test/data/node/test_regex_full_match_basic/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/input_0.pb
new file mode 100644
index 00000000000..8b647d880a7
--- /dev/null
+++ b/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/input_0.pb
@@ -0,0 +1 @@
+2www.google.com2www.facebook.com2
www.bbc.co.ukBX
\ No newline at end of file
diff --git a/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/output_0.pb
new file mode 100644
index 00000000000..b676b4374d1
Binary files /dev/null and b/onnx/backend/test/data/node/test_regex_full_match_basic/test_data_set_0/output_0.pb differ
diff --git a/onnx/backend/test/data/node/test_regex_full_match_email_domain/model.onnx b/onnx/backend/test/data/node/test_regex_full_match_email_domain/model.onnx
new file mode 100644
index 00000000000..bf03e4a2dee
Binary files /dev/null and b/onnx/backend/test/data/node/test_regex_full_match_email_domain/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/input_0.pb
new file mode 100644
index 00000000000..9eea4c75e8e
--- /dev/null
+++ b/onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/input_0.pb
@@ -0,0 +1 @@
+2account@gmail.com2account@hotmail.com2 not email2account2@yahoo.comBX
\ No newline at end of file
diff --git a/onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/output_0.pb
new file mode 100644
index 00000000000..d2ce9b09727
Binary files /dev/null and b/onnx/backend/test/data/node/test_regex_full_match_email_domain/test_data_set_0/output_0.pb differ
diff --git a/onnx/backend/test/data/node/test_regex_full_match_empty/model.onnx b/onnx/backend/test/data/node/test_regex_full_match_empty/model.onnx
new file mode 100644
index 00000000000..59486862a30
Binary files /dev/null and b/onnx/backend/test/data/node/test_regex_full_match_empty/model.onnx differ
diff --git a/onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/input_0.pb
new file mode 100644
index 00000000000..998c963c06e
Binary files /dev/null and b/onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/input_0.pb differ
diff --git a/onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/output_0.pb
new file mode 100644
index 00000000000..0ac37705fe2
Binary files /dev/null and b/onnx/backend/test/data/node/test_regex_full_match_empty/test_data_set_0/output_0.pb differ
diff --git a/onnx/defs/operator_sets.h b/onnx/defs/operator_sets.h
index e305708875d..7b500b01b32 100644
--- a/onnx/defs/operator_sets.h
+++ b/onnx/defs/operator_sets.h
@@ -1107,6 +1107,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, GridSample);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, Gelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ConstantOfShape);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, StringConcat);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, RegexFullMatch);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, StringSplit);
// Iterate over schema from ai.onnx version 20
@@ -1118,6 +1119,7 @@ class OpSet_Onnx_ver20 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema());
fn(GetOpSchema());
}
};
diff --git a/onnx/defs/text/defs.cc b/onnx/defs/text/defs.cc
index 01fc70b3237..5ff70cec265 100644
--- a/onnx/defs/text/defs.cc
+++ b/onnx/defs/text/defs.cc
@@ -33,6 +33,34 @@ ONNX_OPERATOR_SET_SCHEMA(
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
}));
+static const char* RegexFullMatch_doc =
+ R"DOC(RegexFullMatch performs a full regex match on each element of the input tensor. If an element fully matches the regex pattern specified as an attribute, the corresponding element in the output is True and it is False otherwise. [RE2](https://github.com/google/re2/wiki/Syntax) regex syntax is used.)DOC";
+ONNX_OPERATOR_SET_SCHEMA(
+ RegexFullMatch,
+ 20,
+ OpSchema()
+ .Input(0, "X", "Tensor with strings to match on.", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
+ .Attr("pattern", "Regex pattern to match on. This must be valid RE2 syntax.", AttributeProto::STRING, false)
+ .Output(
+ 0,
+ "Y",
+ "Tensor of bools indicating if each input string fully matches the regex pattern specified.",
+ "T2",
+ OpSchema::Single,
+ true,
+ 1,
+ OpSchema::NonDifferentiable)
+ .TypeConstraint("T1", {"tensor(string)"}, "Inputs must be UTF-8 strings")
+ .TypeConstraint(
+ "T2",
+ {"tensor(bool)"},
+ "Outputs are bools and are True where there is a full regex match and False otherwise.")
+ .SetDoc(RegexFullMatch_doc)
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+ updateOutputElemType(ctx, 0, TensorProto::BOOL);
+ propagateShapeFromInputToOutput(ctx, 0, 0);
+ }));
+
static const char* StringSplit_doc =
R"DOC(StringSplit splits a string tensor's elements into substrings based on a delimiter attribute and a maxsplit attribute.
diff --git a/onnx/reference/ops/_op_list.py b/onnx/reference/ops/_op_list.py
index be0925741a8..5caabeed7ce 100644
--- a/onnx/reference/ops/_op_list.py
+++ b/onnx/reference/ops/_op_list.py
@@ -176,6 +176,7 @@
ReduceSumSquare_1,
ReduceSumSquare_18,
)
+from onnx.reference.ops.op_regex_full_match import RegexFullMatch
from onnx.reference.ops.op_relu import Relu
from onnx.reference.ops.op_reshape import Reshape_5, Reshape_14
from onnx.reference.ops.op_resize import Resize
diff --git a/onnx/reference/ops/op_regex_full_match.py b/onnx/reference/ops/op_regex_full_match.py
new file mode 100644
index 00000000000..c3d65849113
--- /dev/null
+++ b/onnx/reference/ops/op_regex_full_match.py
@@ -0,0 +1,34 @@
+# Copyright (c) ONNX Project Contributors
+
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=R0912,R0913,W0221
+
+import numpy as np
+
+from onnx.reference.op_run import OpRun
+
+_acceptable_str_dtypes = ("U", "O")
+
+
+class RegexFullMatch(OpRun):
+ def _run(self, x, pattern=None):
+ try:
+ # pylint: disable=import-outside-toplevel`
+ import re2
+ except ImportError as e:
+ raise ImportError(
+ "re2 must be installed to use the reference implementation of the RegexFullMatch operator"
+ ) from e
+
+ # As per onnx/mapping.py, object numpy dtype corresponds to TensorProto.STRING
+ if x.dtype.kind not in _acceptable_str_dtypes:
+ raise TypeError(f"Input must be string tensor, received dtype {x.dtype}")
+ try:
+ regex = re2.compile(pattern)
+ except re2.error as e:
+ raise ValueError(f"Invalid regex pattern {pattern!r}") from e
+
+ fullmatch_func = np.vectorize(
+ lambda x: regex.fullmatch(x) is not None, otypes=[np.bool_]
+ )
+ return (fullmatch_func(x),)
diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py
index 4b820ab834a..8fa6dea19c5 100644
--- a/onnx/test/automatic_upgrade_test.py
+++ b/onnx/test/automatic_upgrade_test.py
@@ -1805,6 +1805,16 @@ def test_StringConcat(self) -> None:
[[2, 3]],
)
+ def test_RegexFullMatch(self) -> None:
+ self._test_op_upgrade(
+ "RegexFullMatch",
+ 20,
+ [[2, 3]],
+ [[2, 3]],
+ [TensorProto.STRING],
+ [TensorProto.BOOL],
+ )
+
def test_ops_tested(self) -> None:
all_schemas = onnx.defs.get_all_schemas()
all_op_names = [schema.name for schema in all_schemas if schema.domain == ""]
diff --git a/onnx/test/reference_evaluator_backend_test.py b/onnx/test/reference_evaluator_backend_test.py
index 3b4cd63d4d2..65336296cc8 100644
--- a/onnx/test/reference_evaluator_backend_test.py
+++ b/onnx/test/reference_evaluator_backend_test.py
@@ -22,6 +22,7 @@
import os
import pprint
+import sys
import unittest
try:
@@ -82,6 +83,12 @@
"test_castlike_FLOAT_to_BFLOAT16",
"test_castlike_FLOAT_to_BFLOAT16_expanded",
}
+if sys.platform == "win32":
+ SKIP_TESTS |= {
+ "test_regex_full_match_basic",
+ "test_regex_full_match_email_domain",
+ "test_regex_full_match_empty",
+ }
def assert_allclose_string(expected, value):
diff --git a/onnx/test/reference_evaluator_test.py b/onnx/test/reference_evaluator_test.py
index b659b5e0409..6b56d9d3d7c 100644
--- a/onnx/test/reference_evaluator_test.py
+++ b/onnx/test/reference_evaluator_test.py
@@ -13,6 +13,7 @@
import itertools
import math
+import sys
import unittest
from contextlib import redirect_stdout
from functools import wraps
@@ -4949,6 +4950,59 @@ def test_string_split(
num_splits, np.array(expected_num_splits, dtype=np.int64)
)
+ @parameterized.parameterized.expand(
+ [
+ (
+ ["www.google.com", "www.facebook.com", "www.bbc.co.uk"],
+ r"www\.[\w.-]+\.\bcom\b",
+ [True, True, False],
+ (3,),
+ ),
+ (
+ [["Onnx", "tensorflow", "Numpy"], ["Pytorch", "Cython", "numba"]],
+ r"^[A-Z][a-z]*$",
+ [[True, False, True], [True, True, False]],
+ (2, 3),
+ ),
+ (
+ [
+ "account@gmail.com",
+ "account@hotmail.com",
+ "not email",
+ "account2@yahoo.com",
+ ],
+ r"(\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$)",
+ [True, False, False, True],
+ (4,),
+ ),
+ ]
+ )
+ @unittest.skipIf(
+ sys.platform == "win32", "google-re2 package is not built for win32"
+ )
+ def test_regex_full_match(self, x, pattern, expected, expected_shape):
+ X = make_tensor_value_info("X", TensorProto.STRING, None)
+ Y = make_tensor_value_info("Y", TensorProto.BOOL, None)
+ node = make_node("RegexFullMatch", inputs=["X"], outputs=["Y"], pattern=pattern)
+ model = make_model(make_graph([node], "g", [X], [Y]))
+ ref = ReferenceEvaluator(model)
+ result, *_ = ref.run(None, {"X": np.array(x)})
+ np.testing.assert_array_equal(result, expected)
+ self.assertEqual(result.dtype.kind, "b")
+ self.assertEqual(result.shape, expected_shape)
+
+ @unittest.skipIf(
+ sys.platform == "win32", "google-re2 package is not built for win32"
+ )
+ def test_regex_invalid_pattern(self):
+ X = make_tensor_value_info("X", TensorProto.STRING, None)
+ Y = make_tensor_value_info("Y", TensorProto.BOOL, None)
+ node = make_node("RegexFullMatch", inputs=["X"], outputs=["Y"], pattern="x)")
+ model = make_model(make_graph([node], "g", [X], [Y]))
+ ref = ReferenceEvaluator(model)
+ with self.assertRaises(ValueError):
+ ref.run(None, {"X": np.array(["x"])})
+
if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py
index 12090283dc6..3cae5e97957 100644
--- a/onnx/test/shape_inference_test.py
+++ b/onnx/test/shape_inference_test.py
@@ -1622,6 +1622,32 @@ def test_stringconcat_broadcasting(self, _, version) -> None:
opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
)
+ @parameterized.expand(all_versions_for("RegexFullMatch"))
+ def test_regex_full_match(self, _, version) -> None:
+ graph = self._make_graph(
+ [("x", TensorProto.STRING, (2, 4, 3, 9))],
+ [make_node("RegexFullMatch", ["x"], ["y"], pattern=r"^[A-Z][a-z]*$")],
+ [],
+ )
+ self._assert_inferred(
+ graph,
+ [make_tensor_value_info("y", TensorProto.BOOL, (2, 4, 3, 9))],
+ opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
+ )
+
+ @parameterized.expand(all_versions_for("RegexFullMatch"))
+ def test_regex_full_match_empty_shape(self, _, version) -> None:
+ graph = self._make_graph(
+ [("x", TensorProto.STRING, ())],
+ [make_node("RegexFullMatch", ["x"], ["y"], pattern=r"^[A-Z][a-z]*$")],
+ [],
+ )
+ self._assert_inferred(
+ graph,
+ [make_tensor_value_info("y", TensorProto.BOOL, ())],
+ opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
+ )
+
def test_unsqueeze_regular(self) -> None:
graph = self._make_graph(
[("x", TensorProto.FLOAT, (3, 2)), ("axes", TensorProto.INT64, (4,))],
diff --git a/onnx/test/test_backend_onnxruntime.py b/onnx/test/test_backend_onnxruntime.py
index de0b08fe802..974e7e0c378 100644
--- a/onnx/test/test_backend_onnxruntime.py
+++ b/onnx/test/test_backend_onnxruntime.py
@@ -252,6 +252,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
"|equal"
"|identity"
"|reshape"
+ "|regex_full_match"
"|string_split"
"|string_concat"
"|gelu"
diff --git a/onnx/test/test_backend_reference.py b/onnx/test/test_backend_reference.py
index 0beafa9e59f..3627060f4bd 100644
--- a/onnx/test/test_backend_reference.py
+++ b/onnx/test/test_backend_reference.py
@@ -4,6 +4,7 @@
import os
import platform
+import sys
import unittest
from typing import Any
@@ -172,6 +173,12 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
# The following tests fail due to discrepancies (small but still higher than 1e-7).
backend_test.exclude("test_adam_multiple") # 1e-2
+# Currently google-re2 is not supported on Win32 and is required for the reference implementation of RegexFullMatch.
+if sys.platform == "win32":
+ backend_test.exclude("test_regex_full_match_basic_cpu")
+ backend_test.exclude("test_regex_full_match_email_domain_cpu")
+ backend_test.exclude("test_regex_full_match_empty_cpu")
+
# import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 05cbbec5bfa..66aa8f93a9d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -11,4 +11,6 @@ twine
# Dependencies for linting. Versions match those in setup.py.
lintrunner>=0.10.7
lintrunner-adapters>=0.8
+# Dependencies for the reference implementation.
+google-re2
# Edit additional linter dependencies in requirements-lintrunner.txt
diff --git a/requirements-reference.txt b/requirements-reference.txt
new file mode 100644
index 00000000000..2347a47f5cb
--- /dev/null
+++ b/requirements-reference.txt
@@ -0,0 +1 @@
+google-re2
diff --git a/requirements-release.txt b/requirements-release.txt
index 96c766cfef4..fee00ee032e 100644
--- a/requirements-release.txt
+++ b/requirements-release.txt
@@ -7,3 +7,4 @@ wheel
setuptools
twine
parameterized
+google-re2
diff --git a/setup.py b/setup.py
index ad5981d1928..4ff8d769100 100644
--- a/setup.py
+++ b/setup.py
@@ -355,6 +355,12 @@ def build_extensions(self):
"lintrunner-adapters>=0.3",
]
+if not os.path.exists("requirements-reference.txt"):
+ raise FileNotFoundError("Unable to find requirements-reference.txt")
+
+with open("requirements-reference.txt") as f:
+ extras_require["reference"] = f.read().splitlines()
+
################################################################################
# Final
################################################################################