From: A. Unique TensorFlower Date: Mon, 16 Apr 2018 20:32:12 +0000 (-0700) Subject: Adding several utility functions to TF2XLA to help with the Cholesky refactor. Mainl... X-Git-Tag: upstream/v1.9.0_rc1~287^2~1^2~38 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c877eb3fcdff70ed43bfbd54df9eb678e3268eb5;p=platform%2Fupstream%2Ftensorflow.git Adding several utility functions to TF2XLA to help with the Cholesky refactor. Mainly responsible for handling batching properly. PiperOrigin-RevId: 193090634 --- diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 344773c..ea6e1a4 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -126,6 +126,30 @@ cc_library( ], ) +xla_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":batch_dot", + ":util", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop", srcs = ["while_loop.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index f579669..31d823c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -140,13 +140,47 @@ xla::StatusOr SliceInMinorDims( return builder->Slice(x, padded_start, padded_end, strides); } +std::vector PrependMajorDims(xla::ComputationBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices) { + std::vector output(indices.size() + major_dims.size()); + std::copy(major_dims.begin(), major_dims.end(), output.begin()); + std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); + return output; +} + +xla::StatusOr DynamicSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts, + const gtl::ArraySlice& sizes) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); + return builder->DynamicSlice(x, padded_starts, padded_sizes); +} + xla::StatusOr UpdateSlice( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, const xla::ComputationDataHandle& update, gtl::ArraySlice start) { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector start_as_int32(start.begin(), start.end()); - return builder->DynamicUpdateSlice( - x, update, builder->ConstantR1(start_as_int32)); + auto start_constant = builder->ConstantR1(start_as_int32); + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_ASSIGN_OR_RETURN(std::unique_ptr start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(*start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return builder->DynamicUpdateSlice(x, update, start_constant); } xla::StatusOr UpdateSliceInMinorDims( @@ -162,6 +196,29 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + return builder->DynamicUpdateSlice(x, update, padded_starts); +} + +xla::StatusOr PrependZerosInMajorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + auto zero = builder->Reshape(builder->ConstantR0(0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = + builder->Reshape(starts[i], {1}); + } + return builder->ConcatInDim(padded_starts, 0); +} + xla::StatusOr TransposeInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 51f8baa..b684123 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,16 +32,39 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, xla::PrimitiveType type, double value); +// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros +// prepended until the array is length n_dims. +xla::ComputationDataHandle PrependZerosInMajorDims( + xla::ComputationBuilder* builder, + gtl::ArraySlice starts); + // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, xla::PrimitiveType type, int64 value); +// Builds a vector of zeros of length rank(x) with the last two values being +// those in `starts`. +xla::StatusOr PrependZerosInMajorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts); + // Performs a slice in the minor dimensions of a Tensor. xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, gtl::ArraySlice start, gtl::ArraySlice end); +// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. +std::vector PrependMajorDims(xla::ComputationBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices); + +// Performs a dynamic slice in the minor dimensions of a Tensor. +xla::StatusOr DynamicSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts, + const gtl::ArraySlice& sizes); + // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update xla::StatusOr UpdateSlice( @@ -54,6 +77,11 @@ xla::StatusOr UpdateSliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, + const std::vector& starts); + // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::StatusOr TransposeInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc new file mode 100644 index 0000000..b6bd33a --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using UtilTest = xla::ClientLibraryTestBase; +using UtilLeftLookingTest = xla::ClientLibraryTestBase; + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(UtilTest, Simple2dLookup) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, x, y; + auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); + auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); + auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); + TF_ASSERT_OK(result.status()); + + ComputeAndCompareR2(&builder, {{10}}, + {a_data.get(), x_data.get(), y_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(UtilTest, Simple3dLookup) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, 4})); + + ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, + {a_data.get(), index_data.get()}); +} + +XLA_TEST_F(UtilTest, SimpleSliceUpdate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b, x, y; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); + auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); + + auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected( + {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); + + ComputeAndCompareR2( + &builder, expected, + {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); +} + +XLA_TEST_F(UtilTest, RowBatchDot) { + xla::ComputationBuilder builder(client_, TestName()); + + int n = 4; + + xla::ComputationDataHandle a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, n})); + TF_ASSERT_OK_AND_ASSIGN( + auto dot, BatchDot(&builder, l_index, row, + /*transpose_x=*/false, /*transpose_y=*/true)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +} // namespace +} // namespace tensorflow