From d24b52adff3675809aaa623b0c160a526cd1f12a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 29 May 2018 13:06:57 -0700
Subject: [PATCH] Automated g4 rollback of changelist 198421828

PiperOrigin-RevId: 198444757
---
 tensorflow/compiler/jit/kernels/xla_launch_op.cc |  2 +-
 .../compiler/jit/xla_compile_on_demand_op.cc     |  3 +-
 tensorflow/compiler/tf2xla/tf2xla.cc             |  3 +-
 tensorflow/compiler/tf2xla/xla_compiler.cc       | 65 ++--------------------
 tensorflow/compiler/tf2xla/xla_compiler.h        |  7 +--
 tensorflow/compiler/tf2xla/xla_compiler_test.cc  | 54 ++----------------
 6 files changed, 17 insertions(+), 117 deletions(-)

diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 902fe27..27287e0 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
 
   XlaCompiler::Options options;
   options.client = client;
-  options.device_type = cache->device_type();
+  options.device_type = &cache->device_type();
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index b1943d3..ab644ff 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -151,7 +151,8 @@ Status XlaCompileOnDemandOp::Compile(
   core::ScopedUnref cache_ref(cache);
 
   XlaCompiler::Options options;
-  options.device_type = metadata.jit_device_type();
+  DeviceType device_type = metadata.jit_device_type();
+  options.device_type = &device_type;
   options.client = metadata.client();
   options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(),
                                                    FunctionDefLibrary{});
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index ac768b2..3a08aa8 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -263,7 +263,8 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
   // Compile the graph into an XLA computation.
   XlaCompiler::Options compiler_options;
   compiler_options.client = client;
-  compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
+  DeviceType device_type(DEVICE_CPU_XLA_JIT);
+  compiler_options.device_type = &device_type;
   compiler_options.flib_def = &graph->flib_def();
   compiler_options.graph_def_version = graph->versions().producer();
   compiler_options.allow_cpu_custom_calls = true;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index ccbc74e..f709891 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -83,9 +83,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
     : options_(options),
       initialization_status_(Status::OK()),
       next_step_id_(1),
-      device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
+      device_(
+          new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
       device_mgr_({device_}) {
-  CHECK(!options_.device_type.type_string().empty());
+  // We no longer need the device_type.
+  options_.device_type = nullptr;
+
   if (options_.populate_resource_manager) {
     initialization_status_ =
         (*options_.populate_resource_manager)(device_->resource_manager());
@@ -656,59 +659,6 @@ Status XlaCompiler::CompileSingleOp(
   return CompileGraph(options, name, std::move(graph), args, result);
 }
 
-namespace {
-
-// Check that the ops of all non-functional nodes have been registered.
-string ValidateFunctionDef(const FunctionDef* fdef,
-                           const FunctionLibraryDefinition& flib_def) {
-  std::vector<string> invalid_ops;
-  for (const NodeDef& node : fdef->node_def()) {
-    const string& op = node.op();
-    if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
-      continue;
-    }
-    const OpDef* op_def;
-    if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
-      invalid_ops.push_back(op);
-    }
-  }
-  return tensorflow::str_util::Join(invalid_ops, ", ");
-}
-
-// Check that the graph doesn't have any nodes incompatible with given
-// device_type.
-Status ValidateGraph(const Graph* graph,
-                     const FunctionLibraryDefinition& flib_def,
-                     const DeviceType& device_type, const string& name) {
-  std::vector<string> invalid_ops;
-  for (const Node* node : graph->nodes()) {
-    if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
-      continue;
-    }
-    const FunctionDef* fdef = flib_def.Find(node->def().op());
-    if (fdef) {
-      string error_msg = ValidateFunctionDef(fdef, flib_def);
-      if (!error_msg.empty()) {
-        invalid_ops.push_back(
-            strings::StrCat(node->def().op(), ":{", error_msg, "}"));
-      }
-      continue;
-    }
-    if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
-      invalid_ops.push_back(node->def().op());
-    }
-  }
-  if (!invalid_ops.empty()) {
-    return errors::InvalidArgument(strings::StrCat(
-        "Detected unsupported operations when trying to compile graph ", name,
-        " on ", device_type.type_string(), ":",
-        tensorflow::str_util::Join(invalid_ops, ", ")));
-  }
-  return Status::OK();
-}
-
-}  // namespace
-
 Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
                                  string const& name,
                                  std::unique_ptr<Graph> graph,
@@ -731,11 +681,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
       FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
                                graph.get(), local_flib_def_.get()));
 
-  // Detect ops incompatible with the device_type.
-  // FunctionalizeControlFlow may remove some unsupported ops.
-  TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
-                                   options_.device_type, name));
-
   xla::XlaBuilder builder(name);
   XlaContext* context = new XlaContext(
       this, &builder, options_.allow_cpu_custom_calls,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 76f4c4c..bf496bd 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -18,7 +18,6 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -245,9 +244,9 @@ class XlaCompiler {
   typedef std::function<TensorShape(const TensorShape&, DataType)>
       ShapeRepresentationFn;
 
   struct Options {
-    // Name of the compilation device to use. It must be set by the caller.
-    // The default empty value is invalid.
-    DeviceType device_type = DeviceType("");
+    // Name of the compilation device to use. Needs to be live only during
+    // XlaCompiler's constructor.
+    const DeviceType* device_type = nullptr;
 
     xla::Client* client = nullptr;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 246b386..55772ca 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -45,6 +45,8 @@ namespace tensorflow {
 
 class XlaCompilerTest : public ::testing::Test {
  protected:
+  XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
+
   void SetUp() override {
     client_ = xla::ClientLibrary::LocalClientOrDie();
 
@@ -56,7 +58,7 @@ class XlaCompilerTest : public ::testing::Test {
 
   XlaCompiler::Options DefaultOptions() {
     XlaCompiler::Options options;
-    options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
+    options.device_type = &cpu_device_type_;
     options.client = client_;
     options.flib_def = flib_def_.get();
     return options;
@@ -66,6 +68,7 @@ class XlaCompilerTest : public ::testing::Test {
     return compiler->local_flib_def_.get();
   }
 
+  DeviceType cpu_device_type_;
   xla::Client* client_;
   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
 };
@@ -976,54 +979,5 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
-// Tests a graph which has a function with an invalid op.
-TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
-  XlaCompiler compiler(DefaultOptions());
-
-  FunctionDefLibrary flib;
-  FunctionDef fn = FillFn();
-  NodeDef* node = fn.add_node_def();
-  node->set_name("Invalid");
-  node->set_op("InvalidOp"); /* unsupported op */
-  node = fn.add_node_def();
-  node->set_name("Switch");
-  node->set_op("Switch"); /* control flow node */
-  *flib.add_function() = fn;
-
-  TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
-
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-
-  Scope scope = Scope::NewRootScope().ExitOnError();
-  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
-  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
-  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
-
-  NodeDef def;
-  TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
-                   .Input(value.name(), 0, DT_INT32)
-                   .Input(shape.name(), 1, DT_INT32)
-                   .Finalize(&def));
-  Status status;
-  Node* fill = scope.graph()->AddNode(def, &status);
-  TF_ASSERT_OK(status);
-  TF_ASSERT_OK(scope.DoShapeInference(fill));
-  scope.graph()->AddEdge(value.node(), 0, fill, 0);
-  scope.graph()->AddEdge(shape.node(), 0, fill, 1);
-
-  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
-
-  TF_ASSERT_OK(scope.ToGraph(graph.get()));
-
-  std::vector<XlaCompiler::Argument> args;
-  XlaCompiler::CompilationResult result;
-  status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
-                                 std::move(graph), args, &result);
-  ASSERT_FALSE(status.ok());
-  EXPECT_TRUE(
-      str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
-      << status.error_message();
-}
-
 }  // namespace
 }  // namespace tensorflow
-- 
2.7.4
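A note on the API shape this rollback reinstates: XlaCompiler::Options::device_type
becomes a borrowed `const DeviceType*` rather than an owned `DeviceType` value.
The constructor dereferences the pointer once to build its XlaCompilationDevice
and then nulls out its own copy, so the pointee only has to stay alive until the
constructor returns. The sketch below mirrors the caller-side pattern from the
tf2xla.cc hunk above; the BuildAndUseCompiler helper and its parameters are
hypothetical, added only to illustrate the lifetime contract.

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
// DEVICE_CPU_XLA_JIT lives here; after this patch it is no longer pulled in
// transitively by xla_compiler.h, so callers must include it themselves.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {

// Hypothetical helper (not part of the patch) showing the new usage.
Status BuildAndUseCompiler(xla::Client* client,
                           FunctionLibraryDefinition* flib_def) {
  XlaCompiler::Options options;
  // A stack-local DeviceType suffices: per the new comment in
  // xla_compiler.h, it needs to be live only during XlaCompiler's
  // constructor.
  DeviceType device_type(DEVICE_CPU_XLA_JIT);
  options.device_type = &device_type;
  options.client = client;
  options.flib_def = flib_def;

  // The constructor copies *options.device_type into its
  // XlaCompilationDevice and resets its own copy of the pointer to
  // nullptr; `device_type` may go out of scope once this line returns.
  XlaCompiler compiler(options);

  // ... compile graphs with `compiler` here ...
  return Status::OK();
}

}  // namespace tensorflow

The XlaCompileOnDemandOp hunk follows the same pattern with a local copy of
metadata.jit_device_type(), and the test fixture keeps a DeviceType member
alive for the same reason.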