Automated g4 rollback of changelist 198444757
author    A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 30 May 2018 23:17:45 +0000 (16:17 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 30 May 2018 23:20:14 +0000 (16:20 -0700)
PiperOrigin-RevId: 198637528

tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/xla_compile_on_demand_op.cc
tensorflow/compiler/tf2xla/tf2xla.cc
tensorflow/compiler/tf2xla/xla_compiler.cc
tensorflow/compiler/tf2xla/xla_compiler.h
tensorflow/compiler/tf2xla/xla_compiler_test.cc

index 27287e0..902fe27 100644 (file)
@@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
 
   XlaCompiler::Options options;
   options.client = client;
-  options.device_type = &cache->device_type();
+  options.device_type = cache->device_type();
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
index ab644ff..b1943d3 100644 (file)
@@ -151,8 +151,7 @@ Status XlaCompileOnDemandOp::Compile(
   core::ScopedUnref cache_ref(cache);
 
   XlaCompiler::Options options;
-  DeviceType device_type = metadata.jit_device_type();
-  options.device_type = &device_type;
+  options.device_type = metadata.jit_device_type();
   options.client = metadata.client();
   options.flib_def =
       new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
index 3a08aa8..ac768b2 100644 (file)
@@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
   // Compile the graph into an XLA computation.
   XlaCompiler::Options compiler_options;
   compiler_options.client = client;
-  DeviceType device_type(DEVICE_CPU_XLA_JIT);
-  compiler_options.device_type = &device_type;
+  compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
   compiler_options.flib_def = &graph->flib_def();
   compiler_options.graph_def_version = graph->versions().producer();
   compiler_options.allow_cpu_custom_calls = true;
index f709891..2fce616 100644 (file)
@@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
     : options_(options),
       initialization_status_(Status::OK()),
       next_step_id_(1),
-      device_(
-          new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
+      device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
       device_mgr_({device_}) {
-  // We no longer need the device_type.
-  options_.device_type = nullptr;
-
+  CHECK(!options_.device_type.type_string().empty());
   if (options_.populate_resource_manager) {
     initialization_status_ =
         (*options_.populate_resource_manager)(device_->resource_manager());
@@ -659,6 +656,65 @@ Status XlaCompiler::CompileSingleOp(
   return CompileGraph(options, name, std::move(graph), args, result);
 }
 
+namespace {
+
+// Check that the ops of all non-functional nodes have been registered.
+string ValidateFunctionDef(const FunctionDef* fdef,
+                           const FunctionLibraryDefinition& flib_def) {
+  std::vector<string> invalid_ops;
+  for (const NodeDef& node : fdef->node_def()) {
+    const string& op = node.op();
+    if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
+      continue;
+    }
+    const OpDef* op_def;
+    if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
+      invalid_ops.push_back(op);
+    }
+  }
+  return tensorflow::str_util::Join(invalid_ops, ", ");
+}
+
+// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
+// given device_type, invalid data type, missing attributes...)
+Status ValidateGraph(const Graph* graph,
+                     const FunctionLibraryDefinition& flib_def,
+                     const DeviceType& device_type, const string& name) {
+  std::vector<string> invalid_ops;
+  for (const Node* node : graph->nodes()) {
+    if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
+      continue;
+    }
+    const FunctionDef* fdef = flib_def.Find(node->def().op());
+    if (fdef) {
+      string error_msg = ValidateFunctionDef(fdef, flib_def);
+      if (!error_msg.empty()) {
+        invalid_ops.push_back(
+            strings::StrCat(node->def().op(), ":{", error_msg, "}"));
+      }
+      continue;
+    }
+    const OpDef* op_def;
+    if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) {
+      invalid_ops.push_back(node->def().op());
+      continue;
+    }
+    TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
+    if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
+      invalid_ops.push_back(node->def().op());
+    }
+  }
+  if (!invalid_ops.empty()) {
+    return errors::InvalidArgument(strings::StrCat(
+        "Detected unsupported operations when trying to compile graph ", name,
+        " on ", device_type.type_string(), ":",
+        tensorflow::str_util::Join(invalid_ops, ", ")));
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
 Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
                                  string const& name,
                                  std::unique_ptr<Graph> graph,
@@ -681,6 +737,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
       FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
                                graph.get(), local_flib_def_.get()));
 
+  // Detect invalid nodes.
+  // FunctionalizeControlFlow may remove some nodes from the graph.
+  TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
+                                   options_.device_type, name));
+
   xla::XlaBuilder builder(name);
   XlaContext* context = new XlaContext(
       this, &builder, options_.allow_cpu_custom_calls,
index bf496bd..76f4c4c 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -244,9 +245,9 @@ class XlaCompiler {
   typedef std::function<TensorShape(const TensorShape&, DataType)>
       ShapeRepresentationFn;
   struct Options {
-    // Name of the compilation device to use. Needs to be live only during
-    // XlaCompiler's constructor.
-    const DeviceType* device_type = nullptr;
+    // Name of the compilation device to use. It must be set by the caller.
+    // The default empty value is invalid.
+    DeviceType device_type = DeviceType("");
 
     xla::Client* client = nullptr;
 
index 55772ca..5fbf4b9 100644 (file)
@@ -45,8 +45,6 @@ namespace tensorflow {
 
 class XlaCompilerTest : public ::testing::Test {
  protected:
-  XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
-
   void SetUp() override {
     client_ = xla::ClientLibrary::LocalClientOrDie();
 
@@ -58,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test {
 
   XlaCompiler::Options DefaultOptions() {
     XlaCompiler::Options options;
-    options.device_type = &cpu_device_type_;
+    options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
     options.client = client_;
     options.flib_def = flib_def_.get();
     return options;
@@ -68,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test {
     return compiler->local_flib_def_.get();
   }
 
-  DeviceType cpu_device_type_;
   xla::Client* client_;
   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
 };
@@ -979,5 +976,78 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
+// Tests a graph which has a function with an invalid op.
+TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
+  XlaCompiler compiler(DefaultOptions());
+
+  FunctionDefLibrary flib;
+  FunctionDef fn = FillFn();
+  NodeDef* node = fn.add_node_def();
+  node->set_name("Invalid");
+  node->set_op("InvalidOp"); /* unsupported op */
+  node = fn.add_node_def();
+  node->set_name("Switch");
+  node->set_op("Switch"); /* control flow node */
+  *flib.add_function() = fn;
+
+  TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
+  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+  NodeDef def;
+  TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
+                   .Input(value.name(), 0, DT_INT32)
+                   .Input(shape.name(), 1, DT_INT32)
+                   .Finalize(&def));
+  Status status;
+  Node* fill = scope.graph()->AddNode(def, &status);
+  TF_ASSERT_OK(status);
+  TF_ASSERT_OK(scope.DoShapeInference(fill));
+  scope.graph()->AddEdge(value.node(), 0, fill, 0);
+  scope.graph()->AddEdge(shape.node(), 0, fill, 1);
+
+  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
+
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  std::vector<XlaCompiler::Argument> args;
+  XlaCompiler::CompilationResult result;
+  status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
+                                 std::move(graph), args, &result);
+  ASSERT_FALSE(status.ok());
+  EXPECT_TRUE(
+      str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+      << status.error_message();
+}
+
+// Tests a graph which has a node with invalid data type.
+TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  NodeDef shape;
+  shape.set_name("Shape");
+  shape.set_op("Shape");
+  (*shape.mutable_attr())["T"].set_type(DT_INT32);
+  (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */
+  Status status;
+  Node* shape_node = graph->AddNode(shape, &status);
+  TF_ASSERT_OK(status);
+  graph->AddControlEdge(graph->source_node(), shape_node);
+
+  std::vector<XlaCompiler::Argument> args;
+  XlaCompiler::CompilationResult result;
+  XlaCompiler compiler(DefaultOptions());
+  status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
+                                 std::move(graph), args, &result);
+  ASSERT_FALSE(status.ok());
+  EXPECT_TRUE(str_util::StrContains(status.error_message(),
+                                    "is not in the list of allowed values"))
+      << status.error_message();
+}
+
 }  // namespace
 }  // namespace tensorflow