Add RegexReplace Op that internally calls RE2::Replace.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 1 Mar 2018 14:03:38 +0000 (06:03 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 1 Mar 2018 14:07:59 +0000 (06:07 -0800)
PiperOrigin-RevId: 187467840

tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/regex_replace_op.cc [new file with mode: 0644]
tensorflow/core/ops/string_ops.cc
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/regex_replace_op_test.py [new file with mode: 0644]
tensorflow/python/ops/string_ops.py
tensorflow/tools/api/golden/tensorflow.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt
new file mode 100644 (file)
index 0000000..70ad521
--- /dev/null
@@ -0,0 +1,25 @@
+op {
+  graph_op_name: "RegexReplace"
+  in_arg {
+    name: "input"
+    description: "The text to be processed."
+  }
+  in_arg {
+    name: "pattern"
+    description: "The regular expression to match the input."
+  }
+  in_arg {
+    name: "rewrite"
+    description: "The rewrite to be applied to the matched expresion."
+  }
+  out_arg {
+    name: "output"
+    description: "The text after applying pattern and rewrite."
+  }
+  attr {
+    name: "replace_global"
+    description: "If True, the replacement is global, otherwise the replacement\nis done only on the first match."
+  }
+  summary: "Replaces the match of pattern in input with rewrite."
+  description: "It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)"
+}
index 3426cf6..feacee5 100644 (file)
@@ -4155,6 +4155,7 @@ cc_library(
         ":as_string_op",
         ":base64_ops",
         ":reduce_join_op",
+        ":regex_replace_op",
         ":string_join_op",
         ":string_split_op",
         ":string_to_hash_bucket_op",
@@ -4190,6 +4191,12 @@ tf_kernel_library(
 )
 
 tf_kernel_library(
+    name = "regex_replace_op",
+    prefix = "regex_replace_op",
+    deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
+)
+
+tf_kernel_library(
     name = "string_split_op",
     prefix = "string_split_op",
     deps = STRING_DEPS,
@@ -5063,6 +5070,7 @@ filegroup(
             "scatter_nd_op*",
             "mutex_ops.*",
             "batch_kernels.*",
+            "regex_replace_op.cc",
         ],
     ),
     visibility = ["//visibility:public"],
diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc
new file mode 100644 (file)
index 0000000..59ec854
--- /dev/null
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class RegexReplaceOp : public OpKernel {
+ public:
+  explicit RegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+    const auto& input_flat = input_tensor->flat<string>();
+
+    const Tensor* pattern_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
+                errors::InvalidArgument("Pattern must be scalar, but received ",
+                                        pattern_tensor->shape().DebugString()));
+    const string pattern = pattern_tensor->flat<string>()(0);
+    const RE2 match(pattern);
+    OP_REQUIRES(ctx, match.ok(),
+                errors::InvalidArgument("Invalid pattern: ", pattern,
+                                        ", error: ", match.error()));
+
+    const Tensor* rewrite_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()),
+                errors::InvalidArgument("Rewrite must be scalar, but received ",
+                                        rewrite_tensor->shape().DebugString()));
+    const string rewrite = rewrite_tensor->flat<string>()(0);
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+                                             &output_tensor));
+    auto output_flat = output_tensor->flat<string>();
+    for (size_t i = 0; i < input_flat.size(); ++i) {
+      output_flat(i) = input_flat(i);
+      if (replace_global_) {
+        RE2::GlobalReplace(&output_flat(i), match, rewrite);
+      } else {
+        RE2::Replace(&output_flat(i), match, rewrite);
+      }
+    }
+  }
+
+ private:
+  bool replace_global_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
+                        RegexReplaceOp);
+
+}  // namespace tensorflow
index e4c5bcf..05f216a 100644 (file)
@@ -23,6 +23,20 @@ using shape_inference::DimensionHandle;
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
+REGISTER_OP("RegexReplace")
+    .Input("input: string")
+    .Input("pattern: string")
+    .Input("rewrite: string")
+    .Output("output: string")
+    .Attr("replace_global: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    });
+
 REGISTER_OP("StringToHashBucketFast")
     .Input("input: string")
     .Output("output: int64")
index c9aa4a2..0f13e8b 100644 (file)
@@ -713,6 +713,18 @@ cuda_py_test(
 )
 
 tf_py_test(
+    name = "regex_replace_op_test",
+    size = "small",
+    srcs = ["regex_replace_op_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:string_ops",
+    ],
+)
+
+tf_py_test(
     name = "save_restore_ops_test",
     size = "small",
     srcs = ["save_restore_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
new file mode 100644 (file)
index 0000000..6739ac3
--- /dev/null
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for RegexReplace op from string_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class RegexReplaceOpTest(test.TestCase):
+
+  def testRemovePrefix(self):
+    values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      stripped = string_ops.regex_replace(
+          input_vector, "^(a:|b:)", "", replace_global=False).eval()
+      self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"],
+                          stripped)
+
+  def testRegexReplace(self):
+    values = ["aba\naba", "abcdabcde"]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      stripped = string_ops.regex_replace(input_vector, "a.*a", "(\\0)").eval()
+      self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped)
+
+  def testEmptyMatch(self):
+    values = ["abc", "1"]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      stripped = string_ops.regex_replace(input_vector, "", "x").eval()
+      self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped)
+
+  def testInvalidPattern(self):
+    values = ["abc", "1"]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      invalid_pattern = "A["
+      replace = string_ops.regex_replace(input_vector, invalid_pattern, "x")
+      with self.assertRaisesOpError("Invalid pattern"):
+        replace.eval()
+
+  def testGlobal(self):
+    values = ["ababababab", "abcabcabc", ""]
+    with self.test_session():
+      input_vector = constant_op.constant(values, dtypes.string)
+      stripped = string_ops.regex_replace(input_vector, "ab", "abc",
+                                          True).eval()
+      self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped)
+
+
+if __name__ == "__main__":
+  test.main()
index 0335d24..5bd75b9 100644 (file)
@@ -17,6 +17,7 @@
 
 See the @{$python/string_ops} guide.
 
+@@regex_replace
 @@string_to_hash_bucket_fast
 @@string_to_hash_bucket_strong
 @@string_to_hash_bucket
@@ -139,6 +140,7 @@ def reduce_join(inputs, axis=None,
 reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
     gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
 
+ops.NotDifferentiable("RegexReplace")
 ops.NotDifferentiable("StringToHashBucket")
 ops.NotDifferentiable("StringToHashBucketFast")
 ops.NotDifferentiable("StringToHashBucketStrong")
index 2333736..8c9e7af 100644 (file)
@@ -1601,6 +1601,10 @@ tf_module {
     argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
+    name: "regex_replace"
+    argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
+  member_method {
     name: "register_tensor_conversion_function"
     argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
   }