Add tf.regex_match for regex match support (#19160)
authorYong Tang <yong.tang.github@outlook.com>
Tue, 15 May 2018 00:58:36 +0000 (17:58 -0700)
committerRasmus Munk Larsen <rmlarsen@google.com>
Tue, 15 May 2018 00:58:36 +0000 (17:58 -0700)
* Add tf.regex_match for regex match support

This fix tries to address the issue raised in 18264.
tf.regex_replace is already supported, but there is
currently no regex match support.
This fix adds tf.regex_match support, following a pattern
similar to tf.regex_replace.

This fix fixes 18264.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Update BUILD file for the tf.regex_match kernel

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Register RegexMatch ops

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test cases for tf.regex_match

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Update api_defs

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Update API golden

update with:
```
bazel-bin/tensorflow/tools/api/tests/api_compatibility_test
           --update_goldens True
```

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Expose regex_full_match in tf.strings namespace

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Update golden API

```
bazel-bin/tensorflow/tools/api/tests/api_compatibility_test
           --update_goldens True
```

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_RegexFullMatch.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/regex_full_match_op.cc [new file with mode: 0644]
tensorflow/core/ops/string_ops.cc
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/regex_full_match_op_test.py [new file with mode: 0644]
tensorflow/python/ops/string_ops.py
tensorflow/tools/api/generator/BUILD
tensorflow/tools/api/golden/tensorflow.pbtxt
tensorflow/tools/api/golden/tensorflow.strings.pbtxt [new file with mode: 0644]

diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
new file mode 100644 (file)
index 0000000..8cef243
--- /dev/null
@@ -0,0 +1,30 @@
+op {
+  graph_op_name: "RegexFullMatch"
+  in_arg {
+    name: "input"
+    description: <<END
+A string tensor of the text to be processed.
+END
+  }
+  in_arg {
+    name: "pattern"
+    description: <<END
+A scalar string tensor of the regular expression to match the input.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+A bool tensor with the same shape as `input`.
+END
+  }
+  summary: "Check if the input matches the regex pattern."
+  description: <<END
+The input is a string tensor of any shape. The pattern is a scalar
+string tensor which is applied to every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexFullMatch.pbtxt
new file mode 100644 (file)
index 0000000..ec310c8
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RegexFullMatch"
+  visibility: HIDDEN
+}
index 3fb03cd..d6496c5 100644 (file)
@@ -4249,6 +4249,7 @@ cc_library(
         ":as_string_op",
         ":base64_ops",
         ":reduce_join_op",
+        ":regex_full_match_op",
         ":regex_replace_op",
         ":string_join_op",
         ":string_split_op",
@@ -4286,6 +4287,12 @@ tf_kernel_library(
 )
 
 tf_kernel_library(
+    name = "regex_full_match_op",
+    prefix = "regex_full_match_op",
+    deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
+)
+
+tf_kernel_library(
     name = "regex_replace_op",
     prefix = "regex_replace_op",
     deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
new file mode 100644 (file)
index 0000000..5863a2c
--- /dev/null
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class RegexFullMatchOp : public OpKernel {  // CPU kernel for RegexFullMatch
+ public:
+  explicit RegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {  // per-element full match
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+    const auto& input_flat = input_tensor->flat<string>();  // 1-D indexing
+
+    const Tensor* pattern_tensor;  // must be a scalar (checked below)
+    OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
+                errors::InvalidArgument("Pattern must be scalar, but received ",
+                                        pattern_tensor->shape().DebugString()));
+    const string pattern = pattern_tensor->flat<string>()(0);
+    const RE2 match(pattern);  // built on every Compute() call
+    OP_REQUIRES(ctx, match.ok(),  // otherwise fail with InvalidArgument
+                errors::InvalidArgument("Invalid pattern: ", pattern,
+                                        ", error: ", match.error()));
+
+    Tensor* output_tensor = nullptr;  // bool tensor shaped like the input
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+                                             &output_tensor));
+    auto output_flat = output_tensor->flat<bool>();
+    for (size_t i = 0; i < input_flat.size(); ++i) {  // one result per element
+      output_flat(i) = RE2::FullMatch(input_flat(i), match);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
+                        RegexFullMatchOp);
+
+}  // namespace tensorflow
index 469f193..1d5c743 100644 (file)
@@ -37,6 +37,17 @@ REGISTER_OP("RegexReplace")
       return Status::OK();
     });
 
+REGISTER_OP("RegexFullMatch")  // (input, pattern) -> bool output
+    .Input("input: string")
+    .Input("pattern: string")
+    .Output("output: bool")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;  // input(1) is `pattern`; it must have rank 0
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      c->set_output(0, c->input(0));  // output shape == input shape
+      return Status::OK();
+    });
+
 REGISTER_OP("StringToHashBucketFast")
     .Input("input: string")
     .Output("output: int64")
index c892b6e..ef0d728 100644 (file)
@@ -742,6 +742,18 @@ tf_py_test(
 )
 
 tf_py_test(
+    name = "regex_full_match_op_test",
+    size = "small",
+    srcs = ["regex_full_match_op_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:string_ops",
+    ],
+)
+
+tf_py_test(
     name = "save_restore_ops_test",
     size = "small",
     srcs = ["save_restore_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
new file mode 100644 (file)
index 0000000..5daae1b
--- /dev/null
@@ -0,0 +1,54 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for RegexFullMatch op from string_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class RegexFullMatchOpTest(test.TestCase):  # regex_full_match tests
+
+  def testRegexFullMatch(self):  # pattern must cover the whole string
+    values = ["abaaba", "abcdabcde"]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
+      self.assertAllEqual([True, False], matched)
+
+  def testEmptyMatch(self):  # "" fully matches neither non-empty value
+    values = ["abc", "1"]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      matched = string_ops.regex_full_match(input_vector, "").eval()
+      self.assertAllEqual([False, False], matched)
+
+  def testInvalidPattern(self):  # "A[" must fail at eval(), not at build
+    values = ["abc", "1"]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      invalid_pattern = "A["
+      matched = string_ops.regex_full_match(input_vector, invalid_pattern)
+      with self.assertRaisesOpError("Invalid pattern"):
+        matched.eval()
+
+
+if __name__ == "__main__":
+  test.main()
index 9f58c6a..baf169b 100644 (file)
@@ -39,6 +39,8 @@ from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 # pylint: enable=wildcard-import
 
+# Expose regex_full_match in strings namespace
+tf_export("strings.regex_full_match")(regex_full_match)
 
 @tf_export("string_split")
 def string_split(source, delimiter=" ", skip_empty=True):  # pylint: disable=invalid-name
index a1c5699..edc1e07 100644 (file)
@@ -101,6 +101,7 @@ genrule(
         "api/profiler/__init__.py",
         "api/python_io/__init__.py",
         "api/resource_loader/__init__.py",
+        "api/strings/__init__.py",
         "api/saved_model/__init__.py",
         "api/saved_model/builder/__init__.py",
         "api/saved_model/constants/__init__.py",
index 0b12bc0..823a36d 100644 (file)
@@ -497,6 +497,10 @@ tf_module {
     mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
   }
   member {
+    name: "strings"
+    mtype: "<type \'module\'>"
+  }
+  member {
     name: "summary"
     mtype: "<type \'module\'>"
   }
diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
new file mode 100644 (file)
index 0000000..a3fbe95
--- /dev/null
@@ -0,0 +1,7 @@
+path: "tensorflow.strings"
+tf_module {
+  member_method {
+    name: "regex_full_match"
+    argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}