Move dummy AssertOp and CheckNumericsOp to //third_party/tensorflow/compiler/tf2xla...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 00:07:35 +0000 (17:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 00:15:33 +0000 (17:15 -0700)
Enable type DT_STRING for AssertOp and ConstOp, in order to make dummy Assert compile with a const string (assert message) as its input.

PiperOrigin-RevId: 192695938

14 files changed:
tensorflow/compiler/aot/BUILD
tensorflow/compiler/aot/tests/BUILD
tensorflow/compiler/aot/tests/make_test_graphs.py
tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt [new file with mode: 0644]
tensorflow/compiler/aot/tests/tfcompile_test.cc
tensorflow/compiler/jit/mark_for_compilation_pass.cc
tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
tensorflow/compiler/tf2xla/kernels/BUILD
tensorflow/compiler/tf2xla/kernels/assert_op.cc [new file with mode: 0644]
tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc [new file with mode: 0644]
tensorflow/compiler/tf2xla/tf2xla_util.cc
tensorflow/compiler/tf2xla/tf2xla_util.h
tensorflow/compiler/tf2xla/xla_cpu_backend.cc
tensorflow/compiler/tf2xla/xla_gpu_backend.cc

index fa03b1f..19e6bf6 100644 (file)
@@ -60,6 +60,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
index b053dad..bb73cb1 100644 (file)
@@ -14,6 +14,7 @@ test_suite(
         ":test_graph_tfadd_test",
         ":test_graph_tfadd_with_ckpt_saver_test",
         ":test_graph_tfadd_with_ckpt_test",
+        ":test_graph_tfassert_eq_test",
         ":test_graph_tffunction_test",
         ":test_graph_tfgather_test",
         ":test_graph_tfmatmul_test",
@@ -33,6 +34,7 @@ py_binary(
         "//tensorflow/python",  # TODO(b/34059704): remove when fixed
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client",
+        "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform",
@@ -52,6 +54,7 @@ genrule(
         "test_graph_tfadd_with_ckpt_saver.ckpt",
         "test_graph_tfadd_with_ckpt_saver.pb",
         "test_graph_tfadd_with_ckpt_saver.saver",
+        "test_graph_tfassert_eq.pb",
         "test_graph_tffunction.pb",
         "test_graph_tfgather.pb",
         "test_graph_tfmatmul.pb",
@@ -105,6 +108,17 @@ tf_library(
 )
 
 tf_library(
+    name = "test_graph_tfassert_eq",
+    testonly = 1,
+    config = "test_graph_tfassert_eq.config.pbtxt",
+    cpp_class = "AssertComp",
+    graph = "test_graph_tfassert_eq.pb",
+    tags = [
+        "manual",
+    ],
+)
+
+tf_library(
     name = "test_graph_tffunction",
     testonly = 1,
     config = "test_graph_tffunction.config.pbtxt",
@@ -170,6 +184,7 @@ tf_cc_test(
         ":test_graph_tfadd",
         ":test_graph_tfadd_with_ckpt",
         ":test_graph_tfadd_with_ckpt_saver",
+        ":test_graph_tfassert_eq",
         ":test_graph_tffunction",
         ":test_graph_tfgather",
         ":test_graph_tfmatmul",
index 89c7cd4..67767f5 100644 (file)
@@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import app
@@ -125,6 +126,14 @@ def tfsplits(_):
   array_ops.identity(y, name='result')
 
 
+def tfassert_eq(_):
+  x = array_ops.placeholder(dtypes.int32, name='x_hold')
+  y = array_ops.placeholder(dtypes.int32, name='y_hold')
+  control_flow_ops.Assert(
+      math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
+  math_ops.add(x, math_ops.negative(y), name='x_y_diff')
+
+
 def write_graph(build_graph, out_dir):
   """Build a graph using build_graph and write it out."""
   g = ops.Graph()
@@ -144,6 +153,7 @@ def main(_):
   write_graph(tfmatmulandadd, FLAGS.out_dir)
   write_graph(tffunction, FLAGS.out_dir)
   write_graph(tfsplits, FLAGS.out_dir)
+  write_graph(tfassert_eq, FLAGS.out_dir)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt
new file mode 100644 (file)
index 0000000..8732d17
--- /dev/null
@@ -0,0 +1,16 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+  id { node_name: "x_hold" }
+  shape {
+    dim { size: 1 }
+  }
+}
+feed {
+  id { node_name: "y_hold" }
+  shape {
+    dim { size: 1 }
+  }
+}
+fetch {
+  id { node_name: "x_y_diff" }
+}
index 413efd9..67dbd64 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@@ -413,6 +414,23 @@ TEST(TFCompileTest, Splits) {
   EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
 }
 
+TEST(TFCompileTest, AssertEqAndReturnDiff) {
+  // Assert is converted into a no-op in XLA, so there is no failure even if the
+  // two args are different.
+  AssertComp assert;
+  EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
+  EXPECT_EQ(assert.arg1_data(), assert.args()[1]);
+
+  assert.arg0() = 2;
+  assert.arg1() = 1;
+  const int32 expected_result = assert.arg0() - assert.arg1();
+  EXPECT_TRUE(assert.Run());
+  EXPECT_EQ(assert.error_msg(), "");
+  EXPECT_EQ(assert.result0(), expected_result);
+  EXPECT_EQ(assert.result0_data()[0], expected_result);
+  EXPECT_EQ(assert.result0_data(), assert.results()[0]);
+}
+
 TEST(TFCompileTest, LookupNameIndex) {
   // add doesn't have any names defined in its config.
   AddComp add;
index f32c0f4..0c9fbf3 100644 (file)
@@ -50,6 +50,15 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
   // is really a kind of function call and will be handled by
   // IsCompilableCall().
   if (node.type_string() == "SymbolicGradient") return false;
+  if (node.type_string() == "Const") {
+    // Skip Const op with type DT_STRING, since XLA doesn't support it, but the
+    // registered Const KernelDef says that it does, to support no-op Assert for
+    // tfcompile.
+    const AttrValue* attr = node.attrs().Find("dtype");
+    if (attr != nullptr && attr->type() == DT_STRING) {
+      return false;
+    }
+  }
   return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
 }
 
index 80edaf2..703d882 100644 (file)
@@ -609,5 +609,29 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
   EXPECT_TRUE(clusters.empty());
 }
 
+TEST(XlaCompilationTest, ConstOp) {
+  // valid data type
+  {
+    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+    Scope root = Scope::NewRootScope().ExitOnError();
+    auto c = ops::Const(root.WithOpName("const"), 0.5f);
+    c.node()->AddAttr(kXlaCompileAttr, true);
+    TF_ASSERT_OK(root.ToGraph(graph.get()));
+    TF_ASSERT_OK(MarkForCompilation(&graph));
+    EXPECT_EQ(1, GetClusters(*graph).size());
+  }
+
+  // invalid data type
+  {
+    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+    Scope root = Scope::NewRootScope().ExitOnError();
+    auto c = ops::Const(root.WithOpName("const"), string("string"));
+    c.node()->AddAttr(kXlaCompileAttr, true);
+    TF_ASSERT_OK(root.ToGraph(graph.get()));
+    TF_ASSERT_OK(MarkForCompilation(&graph));
+    EXPECT_TRUE(GetClusters(*graph).empty());
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow
index f1bc7d6..3ba37b0 100644 (file)
@@ -171,6 +171,23 @@ tf_kernel_library(
     ],
 )
 
+# Kernels that have a dummy (no-op) implementation.
+tf_kernel_library(
+    name = "xla_dummy_ops",
+    srcs = [
+        "assert_op.cc",
+        "check_numerics_op.cc",
+    ],
+    deps = [
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/core:array_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:logging_ops_op_lib",
+    ],
+    alwayslink = 1,
+)
+
 # Kernels that only work on CPU, because they use XLA custom calls.
 # Only link this when using the CPU backend for XLA.
 tf_kernel_library(
diff --git a/tensorflow/compiler/tf2xla/kernels/assert_op.cc b/tensorflow/compiler/tf2xla/kernels/assert_op.cc
new file mode 100644 (file)
index 0000000..af4ab5e
--- /dev/null
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+namespace {
+
+// This TensorFlow op supports the Assert primitve.
+class AssertOp : public XlaOpKernel {
+ public:
+  explicit AssertOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  ~AssertOp() override {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    static mutex mu(tensorflow::LINKER_INITIALIZED);
+    static int log_counter = 0;
+
+    mutex_lock l(mu);
+    if (log_counter < 20) {
+      ++log_counter;
+      LOG(WARNING) << "Ignoring Assert operator " << name();
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(AssertOp);
+};
+
+REGISTER_XLA_OP(Name("Assert"), AssertOp);
+
+}  // anonymous namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc b/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc
new file mode 100644 (file)
index 0000000..6061e82
--- /dev/null
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace {
+
+class CheckNumericsOp : public XlaOpKernel {
+ public:
+  explicit CheckNumericsOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    // TODO(b/32223192): add a real implementation of CheckNumerics
+    {
+      static mutex mu(tensorflow::LINKER_INITIALIZED);
+      static int log_counter = 0;
+      mutex_lock l(mu);
+      if (log_counter < 20) {
+        ++log_counter;
+        LOG(WARNING) << "Ignoring CheckNumerics operator " << name();
+      }
+    }
+    ctx->SetOutput(0, ctx->Input(0));
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(CheckNumericsOp);
+};
+
+REGISTER_XLA_OP(Name("CheckNumerics"), CheckNumericsOp);
+
+}  // anonymous namespace
+}  // namespace tensorflow
index 2fc77cc..7ec85aa 100644 (file)
@@ -288,4 +288,13 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
   return Status::OK();
 }
 
+void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+                                   KernelDef* kdef) {
+  for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
+    if (constraint.name() == name) {
+      constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
+    }
+  }
+}
+
 }  // namespace tensorflow
index e5fba8e..745beb3 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -51,6 +52,10 @@ string TensorIdToString(const tf2xla::TensorId& id);
 // edges are considered.
 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
 
+// Add an allowed data type to the AttrConstraint with the given name.
+void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+                                   KernelDef* kdef);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
index 8286480..ead229a 100644 (file)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 
@@ -30,6 +31,12 @@ bool CpuOpFilter(KernelDef* kdef) {
         DT_FLOAT);
     return true;
   }
+  if (kdef->op() == "Const") {
+    AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
+  }
+  if (kdef->op() == "Assert") {
+    AddDtypeToKernalDefConstraint("T", DT_STRING, kdef);
+  }
   return true;
 }
 
index 8ca757e..62168b6 100644 (file)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 
@@ -25,6 +26,12 @@ bool GpuOpFilter(KernelDef* kdef) {
       kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
     return false;
   }
+  if (kdef->op() == "Const") {
+    AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
+  }
+  if (kdef->op() == "Assert") {
+    AddDtypeToKernalDefConstraint("T", DT_STRING, kdef);
+  }
   return true;
 }