[TF:XLA] Work around crash in Gather op on CPU backend by making loop bound a compile...
authorPeter Hawkins <phawkins@google.com>
Tue, 13 Feb 2018 06:45:49 +0000 (22:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 06:49:12 +0000 (22:49 -0800)
PiperOrigin-RevId: 185486148

tensorflow/compiler/tf2xla/kernels/gather_op.cc
tensorflow/compiler/tf2xla/lib/BUILD
tensorflow/compiler/tf2xla/lib/scatter.cc
tensorflow/compiler/tf2xla/lib/while_loop.cc
tensorflow/compiler/tf2xla/lib/while_loop.h

index 24f7bebdad0b1b1f65ab315065a577c9169fe884..7945c05af40df21a798a2cff51fe7f8e935793f6 100644 (file)
@@ -153,9 +153,9 @@ Status XlaGather(const xla::ComputationDataHandle& input,
   };
 
   // Construct the While loop, extract and reshape the output.
-  auto num_indices_value =
-      XlaHelpers::IntegerLiteral(builder, index_type, num_indices);
-  TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices_value, body_fn,
+  xla::PrimitiveType ptype;
+  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
+  TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
                                                     init, "gather", builder));
   *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
   return Status::OK();
index da2fdd1d34737ff679c8831cb7a852b9383a4830..488fda74bf7b5c1d66f8d706a1be3cc1fc29a492 100644 (file)
@@ -131,6 +131,7 @@ cc_library(
     srcs = ["while_loop.cc"],
     hdrs = ["while_loop.h"],
     deps = [
+        ":util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
index e8026ecc8d7bf29a448847ad467ae8575f191e8a..6009243f9774eea24e8049e2bd50fe32f291132f 100644 (file)
@@ -180,10 +180,9 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
     return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
   };
 
-  xla::ComputationDataHandle num_indices_value =
-      IntegerLiteral(builder, indices_shape->element_type(), num_indices);
-  TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices_value, body_fn,
-                                                    init, "scatter", builder));
+  TF_ASSIGN_OR_RETURN(
+      auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
+                                    body_fn, init, "scatter", builder));
   return outputs[2];
 }
 
index d35f6cd13fb23be3e9cff803edb8e3a509d0076a..86c02ac2e65c12d3527c4022df0cc603e522ef7a 100644 (file)
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 
@@ -79,19 +80,16 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
 }
 
 xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
-    const xla::ComputationDataHandle& num_iterations,
+    int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
     gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
     StringPiece name, xla::ComputationBuilder* builder) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> num_iterations_shape,
-                      builder->GetShape(num_iterations));
-  TF_RET_CHECK(xla::ShapeUtil::IsScalar(*num_iterations_shape));
-  xla::PrimitiveType num_iterations_type = num_iterations_shape->element_type();
-
   auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
                            xla::ComputationBuilder* cond_builder)
       -> xla::StatusOr<xla::ComputationDataHandle> {
-    return cond_builder->Lt(values[0], values[1]);
+    return cond_builder->Lt(
+        values[0],
+        IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
   };
   auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
                            xla::ComputationBuilder* body_builder)
@@ -103,9 +101,8 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
     updated_values.push_back(body_builder->Add(
         iteration,
         body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
-    updated_values.push_back(values[1]);
 
-    values.remove_prefix(2);
+    values.remove_prefix(1);
     TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
                         body_function(iteration, values, body_builder));
     updated_values.insert(updated_values.end(), body_outputs.begin(),
@@ -114,15 +111,14 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
   };
 
   std::vector<xla::ComputationDataHandle> values;
-  values.reserve(initial_values.size() + 2);
+  values.reserve(initial_values.size() + 1);
   values.push_back(
       builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
-  values.push_back(num_iterations);
   values.insert(values.end(), initial_values.begin(), initial_values.end());
 
   TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
                                            name, builder));
-  values.erase(values.begin(), values.begin() + 2);
+  values.erase(values.begin(), values.begin() + 1);
   return values;
 }
 
index 562a589fbf508754dbb7e157ec45723bc4f39e4d..2e67a0c99b6deb65fa16ab2dec1727f5cb5fcb92 100644 (file)
@@ -64,7 +64,7 @@ typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
     ForEachIndexBodyFunction;
 
 xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
-    const xla::ComputationDataHandle& num_iterations,
+    int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
     gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
     StringPiece name, xla::ComputationBuilder* builder);