"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
":test_graph_tfadd_test",
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
+ ":test_graph_tfassert_eq_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
"//tensorflow/python", # TODO(b/34059704): remove when fixed
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"test_graph_tfadd_with_ckpt_saver.ckpt",
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
+ "test_graph_tfassert_eq.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
)
tf_library(
+ name = "test_graph_tfassert_eq",
+ testonly = 1,
+ config = "test_graph_tfassert_eq.config.pbtxt",
+ cpp_class = "AssertComp",
+ graph = "test_graph_tfassert_eq.pb",
+ tags = [
+ "manual",
+ ],
+)
+
+tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
+ ":test_graph_tfassert_eq",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
array_ops.identity(y, name='result')
+def tfassert_eq(_):
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ control_flow_ops.Assert(
+ math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
+ math_ops.add(x, math_ops.negative(y), name='x_y_diff')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tfassert_eq, FLAGS.out_dir)
if __name__ == '__main__':
--- /dev/null
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "x_y_diff" }
+}
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, AssertEqAndReturnDiff) {
+ // Assert is converted into a no-op in XLA, so there is no failure even if the
+ // two args are different.
+ AssertComp assert;
+ EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
+ EXPECT_EQ(assert.arg1_data(), assert.args()[1]);
+
+ assert.arg0() = 2;
+ assert.arg1() = 1;
+ const int32 expected_result = assert.arg0() - assert.arg1();
+ EXPECT_TRUE(assert.Run());
+ EXPECT_EQ(assert.error_msg(), "");
+ EXPECT_EQ(assert.result0(), expected_result);
+ EXPECT_EQ(assert.result0_data()[0], expected_result);
+ EXPECT_EQ(assert.result0_data(), assert.results()[0]);
+}
+
TEST(TFCompileTest, LookupNameIndex) {
// add doesn't have any names defined in its config.
AddComp add;
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node.type_string() == "SymbolicGradient") return false;
+ if (node.type_string() == "Const") {
+ // Skip Const op with type DT_STRING, since XLA doesn't support it, but the
+ // registered Const KernelDef says that it does, to support no-op Assert for
+ // tfcompile.
+ const AttrValue* attr = node.attrs().Find("dtype");
+ if (attr != nullptr && attr->type() == DT_STRING) {
+ return false;
+ }
+ }
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
EXPECT_TRUE(clusters.empty());
}
+TEST(XlaCompilationTest, ConstOp) {
+ // valid data type
+ {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ Scope root = Scope::NewRootScope().ExitOnError();
+ auto c = ops::Const(root.WithOpName("const"), 0.5f);
+ c.node()->AddAttr(kXlaCompileAttr, true);
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+ EXPECT_EQ(1, GetClusters(*graph).size());
+ }
+
+ // invalid data type
+ {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ Scope root = Scope::NewRootScope().ExitOnError();
+ auto c = ops::Const(root.WithOpName("const"), string("string"));
+ c.node()->AddAttr(kXlaCompileAttr, true);
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+ EXPECT_TRUE(GetClusters(*graph).empty());
+ }
+}
+
} // namespace
} // namespace tensorflow
],
)
+# Kernels that have a dummy (no-op) implementation.
+tf_kernel_library(
+ name = "xla_dummy_ops",
+ srcs = [
+ "assert_op.cc",
+ "check_numerics_op.cc",
+ ],
+ deps = [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/core:array_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:logging_ops_op_lib",
+ ],
+ alwayslink = 1,
+)
+
# Kernels that only work on CPU, because they use XLA custom calls.
# Only link this when using the CPU backend for XLA.
tf_kernel_library(
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+namespace {
+
+// This TensorFlow op supports the Assert primitve.
+class AssertOp : public XlaOpKernel {
+ public:
+ explicit AssertOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ ~AssertOp() override {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ static mutex mu(tensorflow::LINKER_INITIALIZED);
+ static int log_counter = 0;
+
+ mutex_lock l(mu);
+ if (log_counter < 20) {
+ ++log_counter;
+ LOG(WARNING) << "Ignoring Assert operator " << name();
+ }
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(AssertOp);
+};
+
+REGISTER_XLA_OP(Name("Assert"), AssertOp);
+
+} // anonymous namespace
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace {
+
+class CheckNumericsOp : public XlaOpKernel {
+ public:
+ explicit CheckNumericsOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // TODO(b/32223192): add a real implementation of CheckNumerics
+ {
+ static mutex mu(tensorflow::LINKER_INITIALIZED);
+ static int log_counter = 0;
+ mutex_lock l(mu);
+ if (log_counter < 20) {
+ ++log_counter;
+ LOG(WARNING) << "Ignoring CheckNumerics operator " << name();
+ }
+ }
+ ctx->SetOutput(0, ctx->Input(0));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(CheckNumericsOp);
+};
+
+REGISTER_XLA_OP(Name("CheckNumerics"), CheckNumericsOp);
+
+} // anonymous namespace
+} // namespace tensorflow
return Status::OK();
}
+void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+ KernelDef* kdef) {
+ for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
+ if (constraint.name() == name) {
+ constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
+ }
+ }
+}
+
} // namespace tensorflow
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
// edges are considered.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
+// Add an allowed data type to the AttrConstraint with the given name.
+void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+ KernelDef* kdef);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
DT_FLOAT);
return true;
}
+ if (kdef->op() == "Const") {
+ AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
+ }
+ if (kdef->op() == "Assert") {
+ AddDtypeToKernalDefConstraint("T", DT_STRING, kdef);
+ }
return true;
}
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
return false;
}
+ if (kdef->op() == "Const") {
+ AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
+ }
+ if (kdef->op() == "Assert") {
+ AddDtypeToKernalDefConstraint("T", DT_STRING, kdef);
+ }
return true;
}