--- /dev/null
+op {
+ graph_op_name: "RegexFullMatch"
+ in_arg {
+ name: "input"
+ description: <<END
+A string tensor of the text to be processed.
+END
+ }
+ in_arg {
+ name: "pattern"
+ description: <<END
+A 1-D string tensor of the regular expression to match the input.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A bool tensor with the same shape as `input`.
+END
+ }
+ summary: "Check if the input matches the regex pattern."
+ description: <<END
+The input is a string tensor of any shape. The pattern is a scalar
+string tensor which is applied to every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+}
--- /dev/null
+op {
+ graph_op_name: "RegexFullMatch"
+ visibility: HIDDEN
+}
":as_string_op",
":base64_ops",
":reduce_join_op",
+ ":regex_full_match_op",
":regex_replace_op",
":string_join_op",
":string_split_op",
)
tf_kernel_library(
+ name = "regex_full_match_op",
+ prefix = "regex_full_match_op",
+ deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
+)
+
+tf_kernel_library(
name = "regex_replace_op",
prefix = "regex_replace_op",
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class RegexFullMatchOp : public OpKernel {
+ public:
+ explicit RegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ const Tensor* pattern_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
+ errors::InvalidArgument("Pattern must be scalar, but received ",
+ pattern_tensor->shape().DebugString()));
+ const string pattern = pattern_tensor->flat<string>()(0);
+ const RE2 match(pattern);
+ OP_REQUIRES(ctx, match.ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", match.error()));
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<bool>();
+ for (size_t i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = RE2::FullMatch(input_flat(i), match);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
+ RegexFullMatchOp);
+
+} // namespace tensorflow
return Status::OK();
});
+REGISTER_OP("RegexFullMatch")
+ .Input("input: string")
+ .Input("pattern: string")
+ .Output("output: bool")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ });
+
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
)
tf_py_test(
+ name = "regex_full_match_op_test",
+ size = "small",
+ srcs = ["regex_full_match_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
+tf_py_test(
name = "save_restore_ops_test",
size = "small",
srcs = ["save_restore_ops_test.py"],
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for RegexFullMatch op from string_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class RegexFullMatchOpTest(test.TestCase):
+
+ def testRegexFullMatch(self):
+ values = ["abaaba", "abcdabcde"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
+ self.assertAllEqual([True, False], matched)
+
+ def testEmptyMatch(self):
+ values = ["abc", "1"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ matched = string_ops.regex_full_match(input_vector, "").eval()
+ self.assertAllEqual([False, False], matched)
+
+ def testInvalidPattern(self):
+ values = ["abc", "1"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ invalid_pattern = "A["
+ matched = string_ops.regex_full_match(input_vector, invalid_pattern)
+ with self.assertRaisesOpError("Invalid pattern"):
+ matched.eval()
+
+
+if __name__ == "__main__":
+ test.main()
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+# Expose regex_full_match in strings namespace
+tf_export("strings.regex_full_match")(regex_full_match)
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"api/profiler/__init__.py",
"api/python_io/__init__.py",
"api/resource_loader/__init__.py",
+ "api/strings/__init__.py",
"api/saved_model/__init__.py",
"api/saved_model/builder/__init__.py",
"api/saved_model/constants/__init__.py",
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "strings"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "summary"
mtype: "<type \'module\'>"
}
--- /dev/null
+path: "tensorflow.strings"
+tf_module {
+ member_method {
+ name: "regex_full_match"
+ argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}