Adding several utility functions to TF2XLA to help with the Cholesky refactor. Mainl...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Apr 2018 20:32:12 +0000 (13:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 16 Apr 2018 20:34:26 +0000 (13:34 -0700)
PiperOrigin-RevId: 193090634

tensorflow/compiler/tf2xla/lib/BUILD
tensorflow/compiler/tf2xla/lib/util.cc
tensorflow/compiler/tf2xla/lib/util.h
tensorflow/compiler/tf2xla/lib/util_test.cc [new file with mode: 0644]

index 344773c..ea6e1a4 100644 (file)
@@ -126,6 +126,30 @@ cc_library(
     ],
 )
 
+xla_test(
+    name = "util_test",
+    srcs = ["util_test.cc"],
+    deps = [
+        ":batch_dot",
+        ":util",
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "while_loop",
     srcs = ["while_loop.cc"],
index f579669..31d823c 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -140,13 +140,47 @@ xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
   return builder->Slice(x, padded_start, padded_end, strides);
 }
 
+std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+                                    const gtl::ArraySlice<int64>& major_dims,
+                                    const gtl::ArraySlice<int64>& indices) {
+  std::vector<int64> output(indices.size() + major_dims.size());
+  std::copy(major_dims.begin(), major_dims.end(), output.begin());
+  std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size());
+  return output;
+}
+
+xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+    const std::vector<xla::ComputationDataHandle>& starts,
+    const gtl::ArraySlice<int64>& sizes) {
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+  int64 n_minor_dims = starts.size();
+  TF_RET_CHECK(n_minor_dims == sizes.size());
+  TF_RET_CHECK(n_minor_dims <= n_dims);
+  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+                                    /*pos=*/0,
+                                    /*len=*/n_dims - sizes.size());
+  TF_ASSIGN_OR_RETURN(auto padded_starts,
+                      PrependZerosInMajorDims(builder, x, starts));
+  auto padded_sizes = PrependMajorDims(builder, major_dims, sizes);
+  return builder->DynamicSlice(x, padded_starts, padded_sizes);
+}
+
 xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
     const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
   // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
   std::vector<int32> start_as_int32(start.begin(), start.end());
-  return builder->DynamicUpdateSlice(
-      x, update, builder->ConstantR1<int32>(start_as_int32));
+  auto start_constant = builder->ConstantR1<int32>(start_as_int32);
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> start_constant_shape,
+                      builder->GetShape(start_constant));
+  const int64 start_length =
+      xla::ShapeUtil::GetDimension(*start_constant_shape, -1);
+  TF_RET_CHECK(start_length == n_dims);
+  return builder->DynamicUpdateSlice(x, update, start_constant);
 }
 
 xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
@@ -162,6 +196,29 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
   return UpdateSlice(builder, x, update, padded_start);
 }
 
+xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+    const xla::ComputationDataHandle& update,
+    const std::vector<xla::ComputationDataHandle>& starts) {
+  TF_ASSIGN_OR_RETURN(auto padded_starts,
+                      PrependZerosInMajorDims(builder, x, starts));
+  return builder->DynamicUpdateSlice(x, update, padded_starts);
+}
+
+xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+    const std::vector<xla::ComputationDataHandle>& starts) {
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+  auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
+  std::vector<xla::ComputationDataHandle> padded_starts(n_dims, zero);
+  for (int i = 0; i < starts.size(); ++i) {
+    padded_starts[n_dims - starts.size() + i] =
+        builder->Reshape(starts[i], {1});
+  }
+  return builder->ConcatInDim(padded_starts, 0);
+}
+
 xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
index 51f8baa..b684123 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -32,16 +32,39 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
 xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
                                         xla::PrimitiveType type, double value);
 
+// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
+// prepended until the array is length n_dims.
+xla::ComputationDataHandle PrependZerosInMajorDims(
+    xla::ComputationBuilder* builder,
+    gtl::ArraySlice<xla::ComputationDataHandle> starts);
+
 // Returns a integer scalar constant of 'type' with 'value'.
 // If 'type' is complex, returns a real value with zero imaginary component.
 xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
                                           xla::PrimitiveType type, int64 value);
 
+// Builds a vector of zeros of length rank(x) with the last two values being
+// those in `starts`.
+xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+    const std::vector<xla::ComputationDataHandle>& starts);
+
 // Performs a slice in the minor dimensions of a Tensor.
 xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
     gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end);
 
+// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`.
+std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+                                    const gtl::ArraySlice<int64>& major_dims,
+                                    const gtl::ArraySlice<int64>& indices);
+
+// Performs a dynamic slice in the minor dimensions of a Tensor.
+xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+    const std::vector<xla::ComputationDataHandle>& starts,
+    const gtl::ArraySlice<int64>& sizes);
+
 // Updates a slice of 'x', i.e.,
 // x[start[0], ..., start[n]] = update
 xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
@@ -54,6 +77,11 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
     const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
 
+xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+    const xla::ComputationDataHandle& update,
+    const std::vector<xla::ComputationDataHandle>& starts);
+
 // Transposes a stack of matrices `x` by swapping the last two dimensions.
 xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
new file mode 100644 (file)
index 0000000..b6bd33a
--- /dev/null
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+using UtilTest = xla::ClientLibraryTestBase;
+using UtilLeftLookingTest = xla::ClientLibraryTestBase;
+
+xla::Array2D<float> BValsRight() {
+  return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
+}
+
+xla::Array2D<float> BValsLeft() {
+  return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
+}
+
+xla::Array2D<float> AValsFull() {
+  return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
+}
+
+xla::Array3D<float> BatchedAValsFull() {
+  return {{
+              {2, 0, 1, 2},
+              {3, 6, 0, 1},
+              {4, 7, 9, 0},
+              {5, 8, 10, 11},
+          },
+          {
+              {16, 24, 8, 12},
+              {24, 61, 82, 48},
+              {8, 82, 456, 106},
+              {12, 48, 106, 62},
+          }};
+}
+
+XLA_TEST_F(UtilTest, Simple2dLookup) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, x, y;
+  auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a);
+  auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x);
+  auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y);
+  auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1});
+  TF_ASSERT_OK(result.status());
+
+  ComputeAndCompareR2<float>(&builder, {{10}},
+                             {a_data.get(), x_data.get(), y_data.get()},
+                             xla::ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(UtilTest, Simple3dLookup) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, index;
+  auto a_data =
+      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
+  auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto l_index,
+      DynamicSliceInMinorDims(&builder, a,
+                              {index, builder.ConstantR0<int32>(0)}, {1, 4}));
+
+  ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
+                             {a_data.get(), index_data.get()});
+}
+
+XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, b, x, y;
+  auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
+  auto b_data = CreateR2Parameter<float>({{9, 1, -10}}, 1, "b", &builder, &b);
+  auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x);
+  auto y_data = CreateR0Parameter<int>(1, 3, "y", &builder, &y);
+
+  auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y});
+  TF_ASSERT_OK(result.status());
+
+  xla::Array2D<float> expected(
+      {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}});
+
+  ComputeAndCompareR2<float>(
+      &builder, expected,
+      {a_data.get(), b_data.get(), x_data.get(), y_data.get()});
+}
+
+XLA_TEST_F(UtilTest, RowBatchDot) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  int n = 4;
+
+  xla::ComputationDataHandle a, row, index;
+  auto a_data =
+      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
+  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
+                                           "row", &builder, &row);
+  // Select {{3, 6, 0, 1}, {24, 61,  82,  48}} out of BatchedAValsFull().
+  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto l_index,
+      DynamicSliceInMinorDims(&builder, a,
+                              {index, builder.ConstantR0<int32>(0)}, {1, n}));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto dot, BatchDot(&builder, l_index, row,
+                         /*transpose_x=*/false, /*transpose_y=*/true));
+
+  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
+                             {a_data.get(), row_data.get(), index_data.get()});
+}
+
+}  // namespace
+}  // namespace tensorflow