From dff3875cdca6a8cf49ee5ce4c0c970eda550157f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 16:17:45 -0700 Subject: [PATCH] Automated g4 rollback of changelist 198444757 PiperOrigin-RevId: 198637528 --- tensorflow/compiler/jit/kernels/xla_launch_op.cc | 2 +- .../compiler/jit/xla_compile_on_demand_op.cc | 3 +- tensorflow/compiler/tf2xla/tf2xla.cc | 3 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 71 ++++++++++++++++++-- tensorflow/compiler/tf2xla/xla_compiler.h | 7 +- tensorflow/compiler/tf2xla/xla_compiler_test.cc | 78 ++++++++++++++++++++-- 6 files changed, 147 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 27287e0..902fe27 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; - options.device_type = &cache->device_type(); + options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index ab644ff..b1943d3 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -151,8 +151,7 @@ Status XlaCompileOnDemandOp::Compile( core::ScopedUnref cache_ref(cache); XlaCompiler::Options options; - DeviceType device_type = metadata.jit_device_type(); - options.device_type = &device_type; + options.device_type = metadata.jit_device_type(); options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 3a08aa8..ac768b2 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f709891..2fce616 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_( - new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_({device_}) { - // We no longer need the device_type. - options_.device_type = nullptr; - + CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -659,6 +656,65 @@ Status XlaCompiler::CompileSingleOp( return CompileGraph(options, name, std::move(graph), args, result); } +namespace { + +// Check that the ops of all non-functional nodes have been registered. +string ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { + std::vector invalid_ops; + for (const NodeDef& node : fdef->node_def()) { + const string& op = node.op(); + if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { + invalid_ops.push_back(op); + } + } + return tensorflow::str_util::Join(invalid_ops, ", "); +} + +// Check that the graph doesn't have any invalid nodes (e.g. incompatible with +// given device_type, invalid data type, missing attributes...) +Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { + std::vector invalid_ops; + for (const Node* node : graph->nodes()) { + if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { + continue; + } + const FunctionDef* fdef = flib_def.Find(node->def().op()); + if (fdef) { + string error_msg = ValidateFunctionDef(fdef, flib_def); + if (!error_msg.empty()) { + invalid_ops.push_back( + strings::StrCat(node->def().op(), ":{", error_msg, "}")); + } + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { + invalid_ops.push_back(node->def().op()); + continue; + } + TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); + if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { + invalid_ops.push_back(node->def().op()); + } + } + if (!invalid_ops.empty()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ":", + tensorflow::str_util::Join(invalid_ops, ", "))); + } + return Status::OK(); +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -681,6 +737,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), graph.get(), local_flib_def_.get())); + // Detect invalid nodes. + // FunctionalizeControlFlow may remove some nodes from the graph. + TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, + options_.device_type, name)); + xla::XlaBuilder builder(name); XlaContext* context = new XlaContext( this, &builder, options_.allow_cpu_custom_calls, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index bf496bd..76f4c4c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -244,9 +245,9 @@ class XlaCompiler { typedef std::function ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. Needs to be live only during - // XlaCompiler's constructor. - const DeviceType* device_type = nullptr; + // Name of the compilation device to use. It must be set by the caller. + // The default empty value is invalid. + DeviceType device_type = DeviceType(""); xla::Client* client = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 55772ca..5fbf4b9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -45,8 +45,6 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -58,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = &cpu_device_type_; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); options.client = client_; options.flib_def = flib_def_.get(); return options; @@ -68,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } - DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -979,5 +976,78 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests a graph which has a function with an invalid op. +TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + FunctionDef fn = FillFn(); + NodeDef* node = fn.add_node_def(); + node->set_name("Invalid"); + node->set_op("InvalidOp"); /* unsupported op */ + node = fn.add_node_def(); + node->set_name("Switch"); + node->set_op("Switch"); /* control flow node */ + *flib.add_function() = fn; + + TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + std::vector args; + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + << status.error_message(); +} + +// Tests a graph which has a node with invalid data type. +TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef shape; + shape.set_name("Shape"); + shape.set_op("Shape"); + (*shape.mutable_attr())["T"].set_type(DT_INT32); + (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ + Status status; + Node* shape_node = graph->AddNode(shape, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(graph->source_node(), shape_node); + + std::vector args; + XlaCompiler::CompilationResult result; + XlaCompiler compiler(DefaultOptions()); + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "is not in the list of allowed values")) + << status.error_message(); +} + } // namespace } // namespace tensorflow -- 2.7.4