"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/tests:client_library_test_base",
)
cc_library(
- name = "computation_builder",
- srcs = ["computation_builder.cc"],
- hdrs = ["computation_builder.h"],
- deps = [
- ":client",
- ":computation",
- ":global_data",
- ":padding",
- "//tensorflow/compiler/xla:array",
- "//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:array3d",
- "//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "sharding_builder",
srcs = ["sharding_builder.cc"],
hdrs = ["sharding_builder.h"],
+++ /dev/null
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/client/computation_builder.h"
-
-#include <stddef.h>
-#include <array>
-#include <numeric>
-#include <set>
-#include <vector>
-
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace xla {
-
-ComputationBuilder::ComputationBuilder(Client* client,
- const string& computation_name)
- : name_(computation_name), client_(client) {}
-
-ComputationBuilder::~ComputationBuilder() {}
-
-void ComputationBuilder::NoteError(const Status& error) {
- if (die_immediately_on_error_) {
- LOG(FATAL) << "error building computation: " << error;
- }
-
- if (first_error_.ok()) {
- first_error_ = error;
- first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
- }
-}
-
-std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder(
- const string& computation_name) {
- auto sub_builder = MakeUnique<ComputationBuilder>(client_, computation_name);
- sub_builder->parent_builder_ = this;
- sub_builder->die_immediately_on_error_ = die_immediately_on_error_;
- return sub_builder;
-}
-
-Status ComputationBuilder::PrepareComputation() {
- TF_RETURN_IF_ERROR(first_error_);
-
- if (!computation_.IsNull()) {
- return Status::OK();
- }
-
- ComputationRequest request;
- request.set_name(name_);
- ComputationResponse response;
-
- VLOG(2) << "making computation request";
- Status s = client_->stub()->Computation(&request, &response);
- VLOG(2) << "done with computation request";
-
- if (!s.ok()) {
- NoteError(s);
- return first_error_;
- }
-
- computation_ = Computation(client_->stub(), response.computation());
- return Status::OK();
-}
-
-Status ComputationBuilder::RunOp(OpRequest* op_request,
- OpResponse* op_response) {
- TF_RETURN_IF_ERROR(first_error_);
- TF_RETURN_IF_ERROR(PrepareComputation());
-
- // Fill in fields that are set on every OpRequest.
- *op_request->mutable_computation() = computation_.handle();
- *op_request->mutable_metadata() = metadata_;
- if (sharding_) {
- *op_request->mutable_sharding() = *sharding_;
- }
-
- const string& op_name =
- OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name();
- VLOG(2) << "running op request: " << op_name;
- Status status = client_->stub()->Op(op_request, op_response);
- VLOG(2) << "done with op request: " << op_name;
- return status;
-}
-
-void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) {
- OpResponse op_response;
- Status status = RunOp(op_request, &op_response);
- if (!status.ok()) {
- NoteError(status);
- }
-}
-
-ComputationDataHandle ComputationBuilder::RunOpAndParseResponse(
- OpRequest* op_request) {
- OpResponse op_response;
- Status status = RunOp(op_request, &op_response);
- if (!status.ok()) {
- NoteError(status);
- return ComputationDataHandle();
- }
- if (op_response.output().handle() == 0) {
- NoteError(InternalError("No output handle"));
- return ComputationDataHandle();
- }
- return op_response.output();
-}
-
-bool ComputationBuilder::MakeWindow(
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
- const auto verify_size = [&](const size_t x, const char* x_name) {
- if (x == 0 || x == window_dimensions.size()) {
- return true;
- } else {
- NoteError(InvalidArgument(
- "%s", tensorflow::strings::StrCat(
- "Window has different number of window dimensions than of ",
- x_name, "\nNumber of window dimensions: ",
- window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
- "\n")
- .c_str())); //
- return false;
- }
- };
- if (!verify_size(window_strides.size(), "window strides") ||
- !verify_size(padding.size(), "padding entries") ||
- !verify_size(lhs_dilation.size(), "lhs dilation factors") ||
- !verify_size(rhs_dilation.size(), "rhs dilation factors")) {
- return false;
- }
-
- window->Clear();
- for (size_t i = 0; i < window_dimensions.size(); i++) {
- auto dim = window->add_dimensions();
- dim->set_size(window_dimensions[i]);
- if (!window_strides.empty()) {
- dim->set_stride(window_strides[i]);
- } else {
- dim->set_stride(1);
- }
- if (!padding.empty()) {
- dim->set_padding_low(padding[i].first);
- dim->set_padding_high(padding[i].second);
- } else {
- dim->set_padding_low(0);
- dim->set_padding_high(0);
- }
- if (!lhs_dilation.empty()) {
- dim->set_base_dilation(lhs_dilation[i]);
- } else {
- dim->set_base_dilation(1);
- }
- if (!rhs_dilation.empty()) {
- dim->set_window_dilation(rhs_dilation[i]);
- } else {
- dim->set_window_dilation(1);
- }
- dim->set_window_reversal(false);
- }
- return true;
-}
-
-ComputationDataHandle ComputationBuilder::ConstantLiteral(
- const LiteralSlice& literal) {
- OpRequest op_request;
- ConstantRequest* request = op_request.mutable_constant_request();
- *request->mutable_literal() = literal.ToProto();
- VLOG(3) << "created constant: " << request->literal().ShortDebugString();
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
- const Shape& shape,
- const string& name) {
- OpRequest op_request;
- ParameterRequest* request = op_request.mutable_parameter_request();
- *request->mutable_shape() = shape;
- request->set_parameter(parameter_number);
- request->set_name(name);
- return RunOpAndParseResponse(&op_request);
-}
-
-StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShapeWithoutNoteError(
- const ComputationDataHandle& operand) {
- GetLocalShapeRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
- GetLocalShapeResponse response;
-
- VLOG(2) << "making get-shape request";
- TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response));
- VLOG(2) << "done with request";
-
- TF_RET_CHECK(response.has_shape());
- std::unique_ptr<Shape> shape = WrapUnique(response.release_shape());
- TF_RET_CHECK(shape != nullptr);
- return std::move(shape);
-}
-
-StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
- const ComputationDataHandle& operand) {
- TF_RETURN_IF_ERROR(first_error_);
-
- auto status_or_shape = GetShapeWithoutNoteError(operand);
- if (!status_or_shape.ok()) {
- NoteError(status_or_shape.status());
- return first_error_;
- }
- return status_or_shape;
-}
-
-StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
- TF_RETURN_IF_ERROR(first_error_);
-
- GetComputationShapeRequest request;
- *request.mutable_computation() = computation_.handle();
- GetComputationShapeResponse response;
-
- VLOG(2) << "making get-program-shape-request";
- Status status = client_->stub()->GetComputationShape(&request, &response);
- VLOG(2) << "done with get-program-shape-request";
-
- if (!status.ok()) {
- first_error_ = status;
- return status;
- }
-
- TF_RET_CHECK(response.has_program_shape());
- return std::move(*response.mutable_program_shape());
-}
-
-ComputationDataHandle ComputationBuilder::Slice(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
- OpRequest op_request;
- SliceRequest* request = op_request.mutable_slice_request();
- *request->mutable_operand() = operand;
- for (int64 index : start_indices) {
- request->add_start_indices(index);
- }
- for (int64 index : limit_indices) {
- request->add_limit_indices(index);
- }
- for (int64 index : strides) {
- request->add_strides(index);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::SliceInDim(
- const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
- int64 stride, int64 dimno) {
- StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
- if (!shape_status.ok()) {
- NoteError(shape_status.status());
- return ComputationDataHandle{};
- }
- const Shape& shape = *shape_status.ValueOrDie();
- std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
- std::vector<int64> limits(shape.dimensions().begin(),
- shape.dimensions().end());
- std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
- starts[dimno] = start_index;
- limits[dimno] = limit_index;
- strides[dimno] = stride;
- return Slice(operand, starts, limits, strides);
-}
-
-ComputationDataHandle ComputationBuilder::DynamicSlice(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- OpRequest op_request;
- DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request();
- *request->mutable_operand() = operand;
- *request->mutable_start_indices() = start_indices;
- for (int64 index : slice_sizes) {
- request->add_slice_sizes(index);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
- const ComputationDataHandle& operand, const ComputationDataHandle& update,
- const ComputationDataHandle& start_indices) {
- OpRequest op_request;
- DynamicUpdateSliceRequest* request =
- op_request.mutable_dynamic_update_slice_request();
- *request->mutable_operand() = operand;
- *request->mutable_update() = update;
- *request->mutable_start_indices() = start_indices;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::ConcatInDim(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- int64 dimension) {
- OpRequest op_request;
- ConcatenateRequest* request = op_request.mutable_concatenate_request();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- request->set_dimension(dimension);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Broadcast(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- OpRequest op_request;
- BroadcastRequest* request = op_request.mutable_broadcast_request();
- *request->mutable_operand() = operand;
- for (int64 size : broadcast_sizes) {
- request->add_broadcast_sizes(size);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Pad(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& padding_value,
- const PaddingConfig& padding_config) {
- OpRequest op_request;
- PadRequest* request = op_request.mutable_pad_request();
- *request->mutable_operand() = operand;
- *request->mutable_padding_value() = padding_value;
- *request->mutable_padding_config() = padding_config;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Reshape(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
- OpRequest op_request;
- ReshapeRequest* request = op_request.mutable_reshape_request();
- *request->mutable_operand() = operand;
- for (int64 dimension : dimensions) {
- request->add_dimensions(dimension);
- }
- for (int64 new_size : new_sizes) {
- request->add_new_sizes(new_size);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Reshape(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
- std::vector<int64> dimensions(shape.ValueOrDie()->dimensions().size());
- std::iota(dimensions.begin(), dimensions.end(), 0);
- return Reshape(operand, dimensions, new_sizes);
-}
-
-ComputationDataHandle ComputationBuilder::Collapse(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- // Don't support out-of-order collapse here.
- // Checks that the collapsed dimensions are in order and consecutive.
- for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
- i < dimensions.size(); ++i) {
- if (dimensions[i] - 1 != dimensions[i - 1]) {
- NoteError(InvalidArgument(
- "Collapsed dimensions are not in order and consecutive."));
- return ComputationDataHandle();
- }
- }
-
- // Create a new sizes vector from the old shape, replacing the collapsed
- // dimensions by the product of their sizes.
- StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand);
- if (!shape_or_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();
-
- VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
- VLOG(3) << "dims to collapse: "
- << tensorflow::str_util::Join(dimensions, ",");
-
- if (dimensions.size() <= 1) {
- // Not collapsing anything, trivially we can return the operand versus
- // enqueueing a trivial reshape.
- return operand;
- }
-
- std::vector<int64> new_sizes;
- for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) {
- if (i <= dimensions.front() || i > dimensions.back()) {
- new_sizes.push_back(original_shape->dimensions(i));
- } else {
- new_sizes.back() *= original_shape->dimensions(i);
- }
- }
-
- VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
- << "]";
-
- return Reshape(operand, new_sizes);
-}
-
-void ComputationBuilder::Trace(const string& tag,
- const ComputationDataHandle& operand) {
- OpRequest op_request;
- TraceRequest* request = op_request.mutable_trace_request();
- request->set_tag(tag);
- *request->mutable_operand() = operand;
- RunOpAndNoteError(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Select(
- const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
- const ComputationDataHandle& on_false) {
- return TernaryOp(TRIOP_SELECT, pred, on_true, on_false);
-}
-
-ComputationDataHandle ComputationBuilder::Tuple(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
- OpRequest op_request;
- VariadicOpRequest* request = op_request.mutable_variadic_op_request();
- request->set_varop(VAROP_TUPLE);
- for (const ComputationDataHandle& operand : elements) {
- *request->add_operands() = operand;
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::GetTupleElement(
- const ComputationDataHandle& tuple_data, int64 index) {
- OpRequest op_request;
- GetTupleElementRequest* request =
- op_request.mutable_get_tuple_element_request();
- *request->mutable_operand() = tuple_data;
- request->set_index(index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Eq(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Ne(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Ge(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Gt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Le(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Lt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Dot(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
- StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
- if (!lhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
-
- DotDimensionNumbers dimension_numbers;
- dimension_numbers.add_lhs_contracting_dimensions(
- lhs_shape->dimensions_size() == 1 ? 0 : 1);
- dimension_numbers.add_rhs_contracting_dimensions(0);
- return DotGeneral(lhs, rhs, dimension_numbers);
-}
-
-ComputationDataHandle ComputationBuilder::DotGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- const DotDimensionNumbers& dimension_numbers) {
- OpRequest op_request;
- DotRequest* request = op_request.mutable_dot_request();
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- *request->mutable_dimension_numbers() = dimension_numbers;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Conv(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- return ConvWithGeneralDimensions(
- lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
-}
-
-ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- return ConvGeneral(lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
-}
-
-bool ComputationBuilder::VerifyConvolution(
- const Shape& lhs_shape, const Shape& rhs_shape,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
- NoteError(
- InvalidArgument("Convolution arguments must have same number of "
- "dimensions. Got: %s and %s",
- ShapeUtil::HumanString(lhs_shape).c_str(),
- ShapeUtil::HumanString(rhs_shape).c_str()));
- return false;
- }
- int num_dims = ShapeUtil::Rank(lhs_shape);
- if (num_dims < 2) {
- NoteError(InvalidArgument(
- "Convolution expects argument arrays with >= 3 dimensions. "
- "Got: %s and %s",
- ShapeUtil::HumanString(lhs_shape).c_str(),
- ShapeUtil::HumanString(rhs_shape).c_str()));
- return false;
- }
- int num_spatial_dims = num_dims - 2;
-
- const auto check_spatial_dimensions =
- [&](const char* const field_name,
- const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
- numbers) {
- if (numbers.size() != num_spatial_dims) {
- NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
- num_spatial_dims, field_name,
- numbers.size()));
- return false;
- }
- for (int i = 0; i < numbers.size(); ++i) {
- if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
- NoteError(
- InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
- field_name, i, numbers.Get(i)));
- return false;
- }
- }
- return true;
- };
- return check_spatial_dimensions(
- "input_spatial_dimensions",
- dimension_numbers.input_spatial_dimensions()) &&
- check_spatial_dimensions(
- "kernel_spatial_dimensions",
- dimension_numbers.kernel_spatial_dimensions()) &&
- check_spatial_dimensions(
- "output_spatial_dimensions",
- dimension_numbers.output_spatial_dimensions());
-}
-
-ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
- if (!lhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
- if (!rhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
- std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
-
- if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
- NoteError(InternalError("failed to verify convolution"));
- return ComputationDataHandle();
- }
-
- std::vector<int64> base_area_dimensions(
- dimension_numbers.input_spatial_dimensions_size());
- for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
- ++i) {
- base_area_dimensions[i] =
- lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
- }
-
- std::vector<int64> window_dimensions(
- dimension_numbers.kernel_spatial_dimensions_size());
- for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
- window_dimensions[i] =
- rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
- }
-
- return ConvGeneral(lhs, rhs, window_strides,
- MakePadding(base_area_dimensions, window_dimensions,
- window_strides, padding),
- dimension_numbers);
-}
-
-ComputationDataHandle ComputationBuilder::ConvGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
- dimension_numbers);
-}
-
-ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
- if (!lhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
- if (!rhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
- std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
- if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
- // Error is recorded in VerifyConvolution.
- return ComputationDataHandle();
- }
-
- std::vector<int64> window_dimensions(
- dimension_numbers.kernel_spatial_dimensions_size());
- for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
- window_dimensions[i] =
- rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
- }
-
- OpRequest op_request;
- ConvolveRequest* request = op_request.mutable_convolve_request();
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- *request->mutable_dimension_numbers() = dimension_numbers;
-
- if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation,
- rhs_dilation, request->mutable_window())) {
- // Error is recorded in MakeWindow.
- return ComputationDataHandle();
- }
-
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Fft(
- const ComputationDataHandle& operand, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
- OpRequest op_request;
- FftRequest* request = op_request.mutable_fft_request();
- *request->mutable_operand() = operand;
- request->set_fft_type(fft_type);
- for (int64 dim_len : fft_length) {
- request->add_fft_length(dim_len);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
- const string& config) {
- OpRequest op_request;
- InfeedRequest* request = op_request.mutable_infeed_request();
- *request->mutable_shape() = shape;
- *request->mutable_config() = config;
- return RunOpAndParseResponse(&op_request);
-}
-
-void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
- const Shape& shape_with_layout,
- const string& outfeed_config) {
- OpRequest op_request;
- OutfeedRequest* request = op_request.mutable_outfeed_request();
- request->set_outfeed_config(outfeed_config);
- *request->mutable_operand() = operand;
- *request->mutable_shape() = shape_with_layout;
- RunOpAndNoteError(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Call(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
- OpRequest op_request;
- CallRequest* request = op_request.mutable_call_request();
- *request->mutable_to_apply() = computation.handle();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::CustomCall(
- const string& call_target_name,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Shape& shape) {
- OpRequest op_request;
- CustomCallRequest* request = op_request.mutable_custom_call_request();
- request->set_call_target_name(call_target_name);
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- *request->mutable_shape() = shape;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::HostCompute(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const string& channel_name, int64 cost_estimate_ns, const Shape& shape) {
- OpRequest op_request;
- HostComputeRequest* request = op_request.mutable_host_compute_request();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- *request->mutable_shape() = shape;
- request->set_channel_name(channel_name);
- request->set_cost_estimate_ns(cost_estimate_ns);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Complex(
- const ComputationDataHandle& real, const ComputationDataHandle& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Conj(
- const ComputationDataHandle& operand) {
- return Complex(Real(operand), Neg(Imag(operand)));
-}
-
-ComputationDataHandle ComputationBuilder::Add(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Sub(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Mul(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Div(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Rem(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Max(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Min(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::And(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Or(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions);
-}
-
-// TODO(b/65209188): Create a dedicated lowering for Xor
-ComputationDataHandle ComputationBuilder::Xor(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return Or(And(Not(lhs), rhs, broadcast_dimensions),
- And(lhs, Not(rhs), broadcast_dimensions));
-}
-
-ComputationDataHandle ComputationBuilder::Not(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_NOT, operand);
-}
-
-ComputationDataHandle ComputationBuilder::ShiftLeft(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::ShiftRightLogical(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Abs(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_ABS, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Atan2(
- const ComputationDataHandle& y, const ComputationDataHandle& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Exp(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_EXP, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Expm1(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_EXPM1, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Floor(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_FLOOR, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Ceil(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_CEIL, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Round(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Log(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_LOG, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Log1p(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_LOG1P, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Sign(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_SIGN, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Cos(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_COS, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Sin(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_SIN, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Tanh(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_TANH, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Real(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_REAL, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Imag(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_IMAG, operand);
-}
-
-ComputationDataHandle ComputationBuilder::IsFinite(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_IS_FINITE, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Transpose(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> permutation) {
- OpRequest op_request;
- TransposeRequest* request = op_request.mutable_transpose_request();
- *request->mutable_operand() = operand;
- for (int64 dimension : permutation) {
- request->add_dimensions(dimension);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Rev(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- OpRequest op_request;
- ReverseRequest* request = op_request.mutable_reverse_request();
- *request->mutable_operand() = operand;
- for (int64 dimension : dimensions) {
- request->add_dimensions(dimension);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Sort(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_SORT, operand);
-}
-
-ComputationDataHandle ComputationBuilder::SqrtF32(
- const ComputationDataHandle& operand) {
- return BinaryOp(BINOP_POW, operand, ConstantR0<float>(0.5),
- /*broadcast_dimensions=*/{});
-}
-
-ComputationDataHandle ComputationBuilder::Pow(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::ConvertElementType(
- const ComputationDataHandle& operand, PrimitiveType new_element_type) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
- if (!shape_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();
-
- OpRequest op_request;
- ConvertRequest* request = op_request.mutable_convert_request();
- *request->mutable_operand() = operand;
- request->set_new_element_type(new_element_type);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BitcastConvertType(
- const ComputationDataHandle& operand, PrimitiveType new_element_type) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
- if (!shape_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();
-
- OpRequest op_request;
- ConvertRequest* request = op_request.mutable_bitcast_convert_request();
- *request->mutable_operand() = operand;
- request->set_new_element_type(new_element_type);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::SquareF32(
- const ComputationDataHandle& operand) {
- return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),
- /*broadcast_dimensions=*/{});
-}
-
-ComputationDataHandle ComputationBuilder::ReciprocalF32(
- const ComputationDataHandle& operand) {
- return BinaryOp(BINOP_POW, operand, ConstantR0<float>(-1.0),
- /*broadcast_dimensions=*/{});
-}
-
-ComputationDataHandle ComputationBuilder::Neg(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_NEGATE, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Clz(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_CLZ, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Clamp(
- const ComputationDataHandle& min, const ComputationDataHandle& operand,
- const ComputationDataHandle& max) {
- return TernaryOp(TRIOP_CLAMP, min, operand, max);
-}
-
-ComputationDataHandle ComputationBuilder::UnaryOp(
- UnaryOperation unop, const ComputationDataHandle& operand) {
- OpRequest op_request;
- UnaryOpRequest* request = op_request.mutable_unary_op_request();
- request->set_unop(unop);
- *request->mutable_operand() = operand;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BinaryOp(
- BinaryOperation binop, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- OpRequest op_request;
- BinaryOpRequest* request = op_request.mutable_binary_op_request();
- request->set_binop(binop);
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- for (int64 dimension : broadcast_dimensions) {
- request->add_broadcast_dimensions(dimension);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::RngOp(
- RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
- const Shape& shape) {
- OpRequest op_request;
- RngRequest* request = op_request.mutable_rng_request();
- request->set_distribution(distribution);
- for (const ComputationDataHandle& param : parameters) {
- *request->add_parameter() = param;
- }
- *request->mutable_shape() = shape;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::TernaryOp(
- TernaryOperation triop, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
- OpRequest op_request;
- TernaryOpRequest* request = op_request.mutable_ternary_op_request();
- request->set_triop(triop);
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- *request->mutable_ehs() = ehs;
- return RunOpAndParseResponse(&op_request);
-}
-
-Status ComputationBuilder::SetReturnValue(
- const ComputationDataHandle& operand) {
- TF_RETURN_IF_ERROR(first_error_);
-
- SetReturnValueRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
-
- SetReturnValueResponse response;
-
- VLOG(2) << "making set-handle-to-execute request";
- Status s = client_->stub()->SetReturnValue(&request, &response);
- VLOG(2) << "done with request";
-
- if (!s.ok()) {
- NoteError(s);
- return first_error_;
- }
-
- return Status::OK();
-}
-
-StatusOr<bool> ComputationBuilder::IsConstant(
- const ComputationDataHandle& operand, int64 num_parameters) {
- TF_RETURN_IF_ERROR(first_error_);
-
- IsConstantRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
- request.set_num_parameters(num_parameters);
- IsConstantResponse response;
-
- VLOG(2) << "making IsConstant request";
- Status s = client_->stub()->IsConstant(&request, &response);
- VLOG(2) << "done with request";
-
- if (!s.ok()) {
- return s;
- }
- return response.is_constant();
-}
-
-StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
- const ComputationDataHandle& operand, const Layout* output_layout,
- tensorflow::gtl::ArraySlice<Literal> parameters) {
- TF_RETURN_IF_ERROR(first_error_);
-
- ComputeConstantRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
- if (output_layout != nullptr) {
- *request.mutable_output_layout() = *output_layout;
- }
- for (const auto& param : parameters) {
- *request.add_parameters() = param.ToProto();
- }
-
- ComputeConstantResponse response;
-
- VLOG(2) << "making compute-constant request";
- Status s = client_->stub()->ComputeConstant(&request, &response);
- VLOG(2) << "done with request";
-
- if (!s.ok()) {
- return s;
- }
-
- VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
-
- if (!response.has_literal()) {
- return InternalError(
- "no computed literal in the provided response in ComputeConstant "
- "request");
- }
- return Literal::CreateFromProto(response.literal());
-}
-
-ComputationDataHandle ComputationBuilder::Map(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
- OpRequest op_request;
- MapRequest* request = op_request.mutable_map_request();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- *request->mutable_to_apply() = computation.handle();
- for (int64 dimension : dimensions) {
- request->add_dimensions(dimension);
- }
- for (const ComputationDataHandle& sop : static_operands) {
- *request->add_static_operands() = sop;
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::RngNormal(
- const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
- const Shape& shape) {
- return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
-}
-
-ComputationDataHandle ComputationBuilder::RngUniform(
- const ComputationDataHandle& a, const ComputationDataHandle& b,
- const Shape& shape) {
- return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
-}
-
-ComputationDataHandle ComputationBuilder::While(
- const Computation& condition, const Computation& body,
- const ComputationDataHandle& init) {
- OpRequest op_request;
- WhileRequest* request = op_request.mutable_while_request();
- *request->mutable_condition() = condition.handle();
- *request->mutable_body() = body.handle();
- *request->mutable_init() = init;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Gather(
- const ComputationDataHandle& input,
- const ComputationDataHandle& gather_indices,
- const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- OpRequest op_request;
- GatherRequest* gather_request = op_request.mutable_gather_request();
- *gather_request->mutable_input() = input;
- *gather_request->mutable_gather_indices() = gather_indices;
- *gather_request->mutable_dimension_numbers() = dimension_numbers;
- for (int64 window_bound : window_bounds) {
- gather_request->add_window_bounds(window_bound);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Conditional(
- const ComputationDataHandle& predicate,
- const ComputationDataHandle& true_operand,
- const Computation& true_computation,
- const ComputationDataHandle& false_operand,
- const Computation& false_computation) {
- OpRequest op_request;
- ConditionalRequest* request = op_request.mutable_conditional_request();
- *request->mutable_predicate() = predicate;
- *request->mutable_true_operand() = true_operand;
- *request->mutable_true_computation() = true_computation.handle();
- *request->mutable_false_operand() = false_operand;
- *request->mutable_false_computation() = false_computation.handle();
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Reduce(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- OpRequest op_request;
- ReduceRequest* request = op_request.mutable_reduce_request();
- *request->mutable_operand() = operand;
- *request->mutable_init_value() = init_value;
- for (int64 dimension : dimensions_to_reduce) {
- request->add_dimensions(dimension);
- }
- *request->mutable_to_apply() = computation.handle();
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::ReduceAll(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
-
- std::vector<int64> all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie()));
- std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
- return Reduce(operand, init_value, computation, all_dimnos);
-}
-
-ComputationDataHandle ComputationBuilder::ReduceWindow(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
-
- Status padding_valid =
- ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()),
- window_dimensions, window_strides);
- if (!padding_valid.ok()) {
- first_error_ = padding_valid;
- return ComputationDataHandle();
- }
-
- std::vector<std::pair<int64, int64>> padding_values =
- MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
- window_dimensions, window_strides, padding);
- return ReduceWindowWithGeneralPadding(operand, init_value, computation,
- window_dimensions, window_strides,
- padding_values);
-}
-
-ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- OpRequest op_request;
- ReduceWindowRequest* request = op_request.mutable_reduce_window_request();
- *request->mutable_operand() = operand;
- *request->mutable_to_apply() = computation.handle();
- *request->mutable_init_value() = init_value;
-
- if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
- request->mutable_window())) {
- NoteError(InternalError("failed to make window"));
- return ComputationDataHandle();
- }
-
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BatchNormTraining(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
- OpRequest op_request;
- BatchNormTrainingRequest* request =
- op_request.mutable_batch_norm_training_request();
- *request->mutable_operand() = operand;
- *request->mutable_scale() = scale;
- *request->mutable_offset() = offset;
- request->set_epsilon(epsilon);
- request->set_feature_index(feature_index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BatchNormInference(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& offset, const ComputationDataHandle& mean,
- const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
- OpRequest op_request;
- BatchNormInferenceRequest* request =
- op_request.mutable_batch_norm_inference_request();
- *request->mutable_operand() = operand;
- *request->mutable_scale() = scale;
- *request->mutable_offset() = offset;
- *request->mutable_mean() = mean;
- *request->mutable_variance() = variance;
- request->set_epsilon(epsilon);
- request->set_feature_index(feature_index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BatchNormGrad(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& batch_mean,
- const ComputationDataHandle& batch_var,
- const ComputationDataHandle& grad_output, float epsilon,
- int64 feature_index) {
- OpRequest op_request;
- BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request();
- *request->mutable_operand() = operand;
- *request->mutable_scale() = scale;
- *request->mutable_mean() = batch_mean;
- *request->mutable_variance() = batch_var;
- *request->mutable_grad_output() = grad_output;
- request->set_epsilon(epsilon);
- request->set_feature_index(feature_index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::CrossReplicaSum(
- const ComputationDataHandle& operand) {
- OpRequest op_request;
- CrossReplicaSumRequest* request =
- op_request.mutable_cross_replica_sum_request();
- *request->mutable_operand() = operand;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::SelectAndScatter(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
- return SelectAndScatterWithGeneralPadding(
- operand, select, window_dimensions, window_strides,
- MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
- window_dimensions, window_strides, padding),
- source, init_value, scatter);
-}
-
-ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter) {
- OpRequest op_request;
- SelectAndScatterRequest* request =
- op_request.mutable_select_and_scatter_request();
- *request->mutable_operand() = operand;
- *request->mutable_select() = select.handle();
- *request->mutable_source() = source;
- *request->mutable_init_value() = init_value;
- *request->mutable_scatter() = scatter.handle();
-
- if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
- request->mutable_window())) {
- NoteError(InternalError("failed to make window"));
- return ComputationDataHandle();
- }
-
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::ReducePrecision(
- const ComputationDataHandle& operand, const int exponent_bits,
- const int mantissa_bits) {
- OpRequest op_request;
- ReducePrecisionRequest* request =
- op_request.mutable_reduce_precision_request();
- *request->mutable_operand() = operand;
- request->set_exponent_bits(exponent_bits);
- request->set_mantissa_bits(mantissa_bits);
- return RunOpAndParseResponse(&op_request);
-}
-
-void ComputationBuilder::Send(const ComputationDataHandle& operand,
- const ChannelHandle& handle) {
- OpRequest op_request;
- SendRequest* request = op_request.mutable_send_request();
- *request->mutable_operand() = operand;
- *request->mutable_channel_handle() = handle;
- *op_request.mutable_computation() = computation_.handle();
- RunOpAndNoteError(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
- const ChannelHandle& handle) {
- OpRequest op_request;
- RecvRequest* request = op_request.mutable_recv_request();
- *request->mutable_shape() = shape;
- *request->mutable_channel_handle() = handle;
- return RunOpAndParseResponse(&op_request);
-}
-
-Computation ComputationBuilder::BuildAndNoteError() {
- DCHECK(parent_builder_ != nullptr);
- auto build_status = Build();
- if (!build_status.ok()) {
- parent_builder_->NoteError(
- AddStatus(build_status.status(),
- tensorflow::strings::StrCat("error from: ", name_)));
- return Computation();
- }
- return build_status.ConsumeValueOrDie();
-}
-
-StatusOr<Computation> ComputationBuilder::Build() {
- if (!first_error_.ok()) {
- string backtrace;
- first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
- return AppendStatus(first_error_, backtrace);
- }
-
- if (computation_.IsNull()) {
- return FailedPrecondition("no computation was built");
- }
-
- return {std::move(computation_)};
-}
-
-/* static */ ConvolutionDimensionNumbers
-ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
- ConvolutionDimensionNumbers dimension_numbers;
- dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
- dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
- dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
- dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
- dimension_numbers.set_kernel_output_feature_dimension(
- kConvKernelOutputDimension);
- dimension_numbers.set_kernel_input_feature_dimension(
- kConvKernelInputDimension);
- for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_input_spatial_dimensions(i + 2);
- dimension_numbers.add_kernel_spatial_dimensions(i + 2);
- dimension_numbers.add_output_spatial_dimensions(i + 2);
- }
- return dimension_numbers;
-}
-
-/* static */ StatusOr<ConvolutionDimensionNumbers>
-ComputationBuilder::CreateConvDimensionNumbers(
- int64 input_batch, int64 input_feature, int64 input_first_spatial,
- int64 input_second_spatial, int64 output_batch, int64 output_feature,
- int64 output_first_spatial, int64 output_second_spatial,
- int64 kernel_output_feature, int64 kernel_input_feature,
- int64 kernel_first_spatial, int64 kernel_second_spatial) {
- if (std::set<int64>({input_batch, input_feature, input_first_spatial,
- input_second_spatial})
- .size() != 4) {
- return FailedPrecondition(
- "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
- "%lld)",
- input_batch, input_feature, input_first_spatial, input_second_spatial);
- }
- if (std::set<int64>({kernel_output_feature, kernel_input_feature,
- kernel_first_spatial, kernel_second_spatial})
- .size() != 4) {
- return FailedPrecondition(
- "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
- "%lld)",
- kernel_output_feature, kernel_input_feature, kernel_first_spatial,
- kernel_second_spatial);
- }
- if (std::set<int64>({output_batch, output_feature, output_first_spatial,
- output_second_spatial})
- .size() != 4) {
- return FailedPrecondition(
- "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
- "%lld)",
- output_batch, output_feature, output_first_spatial,
- output_second_spatial);
- }
- ConvolutionDimensionNumbers dimension_numbers;
- dimension_numbers.set_input_batch_dimension(input_batch);
- dimension_numbers.set_input_feature_dimension(input_feature);
- dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
- dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
- dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
- dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
- dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
- dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
- dimension_numbers.set_output_batch_dimension(output_batch);
- dimension_numbers.set_output_feature_dimension(output_feature);
- dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
- dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
- return dimension_numbers;
-}
-
-} // namespace xla
+++ /dev/null
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
-#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
-
-#include <functional>
-#include <initializer_list>
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "tensorflow/compiler/xla/array.h"
-#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/array3d.h"
-#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/client.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/global_data.h"
-#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/bitmap.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/stacktrace.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-
-// Wraps an XLA client with a convenient interface for building up
-// computations. Any errors encountered in building up the computation are
-// deferred from being handled until Build() is called.
-//
-// Thread-compatible.
-//
-// TODO(b/74197823): Deprecated. Use XlaBuilder instead.
-class ComputationBuilder {
- public:
- // client: client in which to build the computation.
- // computation_name: name to use for the built computation.
- ComputationBuilder(Client* client, const string& computation_name);
-
- ~ComputationBuilder();
-
- // Returns the client the builder was initialized with.
- Client* client() const { return client_; }
-
- // Returns the computation name.
- const string& name() const { return name_; }
-
- // Sets OpMetadata that will be added to all instructions until cleared.
- //
- // OpMetadata is often applied to a series of XLA HLO instructions. As a
- // result, OpMetadata is set on the Computation Builder. All subsequent
- // instructions generated via this Computation Builder will have the same
- // OpMetadata attached until a call to ClearOpMetadata.
- void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
-
- // Clears the HloMetadata state.
- void ClearOpMetadata() { metadata_.Clear(); }
-
- // Sets an OpSharding that will be attached to all instructions until cleared.
- void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
-
- // Clears the sharding. Ops will be sharded according to the default placement
- // policy.
- void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
-
- // Returns the OpSharding that will be attached to all instructions.
- const tensorflow::gtl::optional<OpSharding>& sharding() const {
- return sharding_;
- }
-
- // Sets the builder to a mode where it will die immediately when an error is
- // encountered, rather than producing it in a deferred fashion when Build() is
- // called (which is the default).
- void set_die_immediately_on_error(bool enabled) {
- die_immediately_on_error_ = enabled;
- }
-
- // Enqueues a "retrieve parameter value" instruction for a parameter that was
- // passed to the computation.
- ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
- const string& name);
-
- // Retrieves the (inferred) shape of the operand in the computation.
- StatusOr<std::unique_ptr<Shape>> GetShape(
- const ComputationDataHandle& operand);
-
- // Retrieves the (inferred) result for the current computation's shape.
- StatusOr<ProgramShape> GetProgramShape();
-
- // Enqueues a constant with the value of the given literal onto the
- // computation.
- ComputationDataHandle ConstantLiteral(const LiteralSlice& literal);
-
- // Enqueues a constant onto the computation. Methods are templated on the
- // native host type (NativeT) which corresponds to a specific XLA
- // PrimitiveType as given in the following table:
- //
- // Native Type PrimitiveType
- // -----------------------------
- // bool PRED
- // int32 S32
- // int64 S64
- // uint32 U32
- // uint64 U64
- // float F32
- // double F64
- //
- // Note: not all primitive types defined in xla_data.proto have a
- // corresponding native type yet.
- template <typename NativeT>
- ComputationDataHandle ConstantR0(NativeT value);
- template <typename NativeT>
- ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
- ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- ComputationDataHandle ConstantFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantR4FromArray4D(const Array4D<NativeT>& values);
-
- // Enqueues a rank one constant (vector) onto the computation. The vector has
- // size 'length' and every element has the value 'value'.
- template <typename NativeT>
- ComputationDataHandle ConstantR1(int64 length, NativeT value);
-
- // Adds dimensions to an array by duplicating the data in the array.
- //
- // The new dimensions are inserted on the left, i.e. if
- // broadcast_sizes has values {a0, ..., aN} and the operand shape
- // has dimensions {b0, ..., bM} then the shape of the output has
- // dimensions {a0, ..., aN, b0, ..., bM}.
- //
- // The new dimensions index into copies of the operand, i.e.
- //
- // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
- ComputationDataHandle Broadcast(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
-
- // Enqueues a pad operation onto the computation that pads the given value on
- // the edges as well as between the elements of the input. padding_config
- // specifies the padding amount for each dimension.
- ComputationDataHandle Pad(const ComputationDataHandle& operand,
- const ComputationDataHandle& padding_value,
- const PaddingConfig& padding_config);
-
- // Enqueues an operation onto the computation that flattens the operand based
- // on the dimension order (major/slowest-varying to minor/fastest-varying)
- // given, followed by reshaping it into the shape with the given dimension
- // sizes (also major to minor). Conceptually, this is a limited form of
- // "shape casting".
- ComputationDataHandle Reshape(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- // Enqueues an operation onto the computation that collapses the operand, from
- // first to last dimension (C order), then reshapes it to the given dimension
- // sizes. Conceptually, this is a limited form of "shape casting".
- ComputationDataHandle Reshape(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- // Wrapper for Reshape.
- // Enqueues an operation to collapse the provided dimensions; e.g. an
- // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
- // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
- // be a consecutive, in-order subsequence of the operand dimensions.
- //
- // Note that collapsing a single dimension does nothing:
- //
- // {256} collapsing {0} => {256}
- // {1} collapsing {0} => {1}
- //
- // Collapsing multiple dimensions produces a single result dimension:
- //
- // {256, 2} collapsing {0,1} => {512}
- // {256, 2, 3} collapsing {0,1} => {512, 3}
- //
- // This could potentially cause data to be moved -- it provides a more
- // structured form of reshaping than an arbitrary Reshape operation.
- ComputationDataHandle Collapse(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Enqueues a slice operation onto the computation that slices the operand
- // from the start indices to the limit indices; e.g.
- //
- // x
- // [ 0 1 2 3 ]
- // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
- // [ 8 9 a b ]
- //
- // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
- // range notation.
- // The strides parameter determines the stride over the slice
- ComputationDataHandle Slice(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
-
- // Enqueues a slice operation in a given dimension, taking all other
- // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
- // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
- // for:
- //
- // array[:, 2:4:1, :]
- ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
- int64 start_index, int64 limit_index,
- int64 stride, int64 dimno);
-
- // Enqueues a slice operation onto the computation that slices the 'operand'
- // from dynamic start indices which are passed in 'start_indices'.
- // The size of the slice in each dimension is passed in 'slice_sizes',
- // which specify the end point of exclusive slice intervals in each
- // dimension [start, start + size).
- // The shape of 'start_indices' must be rank == 1, with dimension size
- // equal to the rank of the 'operand'.
- // Slice index calculations are computed modulo input dimension sizes to
- // prevent dynamic start indices from generating out-of-bound array accesses.
- ComputationDataHandle DynamicSlice(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
-
- // Enqueues a dynamic update slice operation onto the computation, which
- // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
- // The shape of 'update' determines the shape of the slice of 'operand'
- // which is updated.
- // The indices specified in 'start_indices' specify the offset of the slice
- // of 'operand' which is updated.
- //
- // update = {10, 11} // calculated at runtime.
- // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
- // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
- // [7 8 9] [7 8 9 ]
- //
- // The shape of 'start_indices' must be rank == 1, with dimension size
- // equal to the rank of the 'operand'.
- // Slice index calculations are computed modulo update dimension sizes to
- // prevent dynamic start indices from generating out-of-bound array accesses.
- ComputationDataHandle DynamicUpdateSlice(
- const ComputationDataHandle& operand, const ComputationDataHandle& update,
- const ComputationDataHandle& start_indices);
-
- // Enqueues a concatenate instruction onto the computation. 'operands' must
- // have >= 1 entry.
- ComputationDataHandle ConcatInDim(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- int64 dimension);
-
- // Enqueue a tracing operation onto the computation; the computation will emit
- // a logging message with the operand.
- void Trace(const string& tag, const ComputationDataHandle& operand);
-
- // Enqueues a conditional-move-like select operation onto the computation;
- // predicated on pred, selects between on_true and on_false.
- ComputationDataHandle Select(const ComputationDataHandle& pred,
- const ComputationDataHandle& on_true,
- const ComputationDataHandle& on_false);
-
- // Enqueues a tuple-creation instruction onto the computation.
- ComputationDataHandle Tuple(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
-
- // Enqueues a tuple-element-get instruction onto the computation.
- ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
- int64 index);
-
- // Enqueues an equal-to comparison instruction onto the computation.
- ComputationDataHandle Eq(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a not-equal comparison instruction onto the computation.
- ComputationDataHandle Ne(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a greater-or-equal comparison instruction onto the computation.
- ComputationDataHandle Ge(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a greater-than comparison instruction onto the computation.
- ComputationDataHandle Gt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a less-than comparison instruction onto the computation.
- ComputationDataHandle Lt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a less-or-equal comparison instruction onto the computation.
- ComputationDataHandle Le(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a dot instruction onto the computation.
- ComputationDataHandle Dot(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs);
-
- // Enqueues a general dot instruction onto the computation.
- ComputationDataHandle DotGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- const DotDimensionNumbers& dimension_numbers);
-
- // Default dimension numbers used for a 2D convolution.
- static constexpr int64 kConvBatchDimension = 0;
- static constexpr int64 kConvFeatureDimension = 1;
- static constexpr int64 kConvFirstSpatialDimension = 2;
- static constexpr int64 kConvSecondSpatialDimension = 3;
- static constexpr int64 kConvKernelOutputDimension = 0;
- static constexpr int64 kConvKernelInputDimension = 1;
- static constexpr int64 kConvKernelFirstSpatialDimension = 2;
- static constexpr int64 kConvKernelSecondSpatialDimension = 3;
-
- // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
- // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
- // the kernel operand
- // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
- static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
- int num_spatial_dims = 2);
-
- // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an
- // error if either the input or the weight dimension numbers have conflicts.
- static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
- int64 input_batch, int64 input_feature, int64 input_first_spatial,
- int64 input_second_spatial, int64 output_batch, int64 output_feature,
- int64 output_first_spatial, int64 output_second_spatial,
- int64 kernel_output_feature, int64 kernel_input_feature,
- int64 kernel_first_spatial, int64 kernel_second_spatial);
-
- // Enqueues a convolution instruction onto the computation, which uses the
- // default convolution dimension numbers.
- ComputationDataHandle Conv(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration in the format returned by MakePadding().
- ComputationDataHandle ConvWithGeneralPadding(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided dimension numbers configuration.
- ComputationDataHandle ConvWithGeneralDimensions(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration as well as the dimension numbers.
- ComputationDataHandle ConvGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration, dilation factors and dimension numbers.
- ComputationDataHandle ConvGeneralDilated(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues an FFT instruction onto the computation, of the given type and
- // with the given FFT length.
- ComputationDataHandle Fft(const ComputationDataHandle& operand,
- FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
-
- // Enqueues an infeed instruction onto the computation, which writes data of
- // the given shape to the infeed buffer of the device.
- ComputationDataHandle Infeed(const Shape& shape, const string& config = "");
-
- // Enqueues an outfeed instruction onto the computation. This instruction
- // generates outgoing data transfers for the given data.
- //
- // shape_with_layout communicates the laid out shape that we want to outfeed
- // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
- // will occur.
- void Outfeed(const ComputationDataHandle& operand,
- const Shape& shape_with_layout, const string& outfeed_config);
-
- // Enqueues a call instruction onto the computation.
- ComputationDataHandle Call(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
-
- // Enqueues a custom call instruction onto the computation.
- // During code generation, a call instruction is emitted which targets a
- // symbol with the name |call_target_name|. The |operands| are passed to the
- // call instruction. |shape| is the resultant shape.
- ComputationDataHandle CustomCall(
- const string& call_target_name,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Shape& shape);
-
- // Enqueues a pseudo-op to represent host-side computation data-dependencies.
- // During code generation, host send and receive operations will be generated
- // to transfer |operands| to the host and a single result of |shape| back to
- // the device. Host send/recv operations are emitted using |channel_name|.
- // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
- // instruction scheduling.
- ComputationDataHandle HostCompute(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const string& channel_name, int64 cost_estimate_ns, const Shape& shape);
-
- // The following methods enqueue element-wise binary arithmetic operations
- // onto the computation. The shapes of the operands have to match unless one
- // of the operands is a scalar, or an explicit broadcast dimension is given
- // (see g3doc for more details).
-
- // Enqueues a complex compose instruction onto the computation.
- ComputationDataHandle Complex(
- const ComputationDataHandle& real, const ComputationDataHandle& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a complex conjugate instruction onto the computation.
- ComputationDataHandle Conj(const ComputationDataHandle& operand);
-
- // Enqueues an add instruction onto the computation.
- ComputationDataHandle Add(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a subtract instruction onto the computation.
- ComputationDataHandle Sub(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a multiply instruction onto the computation.
- ComputationDataHandle Mul(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a divide instruction onto the computation.
- ComputationDataHandle Div(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a remainder instruction onto the computation.
- ComputationDataHandle Rem(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a max instruction onto the computation.
- ComputationDataHandle Max(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a min instruction onto the computation.
- ComputationDataHandle Min(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Element-wise logical operators
- ComputationDataHandle And(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- ComputationDataHandle Or(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- ComputationDataHandle Xor(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- ComputationDataHandle Not(const ComputationDataHandle& operand);
-
- ComputationDataHandle ShiftLeft(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- ComputationDataHandle ShiftRightArithmetic(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- ComputationDataHandle ShiftRightLogical(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Reduces an array among the provided dimensions, given "computation" as a
- // reduction operator.
- ComputationDataHandle Reduce(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
-
- // Convenience wrapper around the above that reduces all the dimensions in the
- // operand shape.
- ComputationDataHandle ReduceAll(const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value,
- const Computation& computation);
-
- // Enqueues a windowed reduce instruction onto the computation.
- ComputationDataHandle ReduceWindow(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
-
- // As ReduceWindow(), but the padding is given in the format
- // returned by MakePadding().
- ComputationDataHandle ReduceWindowWithGeneralPadding(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
- // Returns the sum of the operand value across all replicas. All replicas
- // supply one input to the sum and all replicas receive the resulting sum.
- ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);
-
- // Enqueues an operation that scatters the `source` array to the selected
- // indices of each window.
- ComputationDataHandle SelectAndScatter(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter);
-
- // As SelectAndScatter(), but the padding is given in the format
- // returned by MakePadding().
- ComputationDataHandle SelectAndScatterWithGeneralPadding(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter);
-
- // Enqueues an abs instruction onto the computation.
- ComputationDataHandle Abs(const ComputationDataHandle& operand);
-
- // Enqueues a atan2 instruction onto the computation.
- ComputationDataHandle Atan2(
- const ComputationDataHandle& y, const ComputationDataHandle& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues an exp instruction onto the computation.
- ComputationDataHandle Exp(const ComputationDataHandle& operand);
-
- // Enqueues an expm1 instruction onto the computation.
- ComputationDataHandle Expm1(const ComputationDataHandle& operand);
-
- // Enqueues a floor instruction onto the computation.
- ComputationDataHandle Floor(const ComputationDataHandle& operand);
-
- // Enqueues a ceil instruction onto the computation.
- ComputationDataHandle Ceil(const ComputationDataHandle& operand);
-
- // Enqueues a round instruction onto the computation, rounding to nearest even
- // with half-way cases rounding away from zero.
- ComputationDataHandle Round(const ComputationDataHandle& operand);
-
- // Enqueues an log instruction (natural logarithm) onto the computation.
- ComputationDataHandle Log(const ComputationDataHandle& operand);
-
- // Enqueues an log1p instruction onto the computation.
- ComputationDataHandle Log1p(const ComputationDataHandle& operand);
-
- // Enqueues a sign instruction onto the computation.
- ComputationDataHandle Sign(const ComputationDataHandle& operand);
-
- // Enqueues a cosine instruction onto the computation.
- ComputationDataHandle Cos(const ComputationDataHandle& operand);
-
- // Enqueues a sine instruction onto the computation.
- ComputationDataHandle Sin(const ComputationDataHandle& operand);
-
- // Enqueues a tanh instruction onto the computation.
- ComputationDataHandle Tanh(const ComputationDataHandle& operand);
-
- // Enqueues a real-part instruction onto the computation.
- ComputationDataHandle Real(const ComputationDataHandle& operand);
-
- // Enqueues an imaginary-part instruction onto the computation.
- ComputationDataHandle Imag(const ComputationDataHandle& operand);
-
- // Enqueues a float32 sqrt instruction onto the computation.
- // (float32 is specified as there is an implicit float32 0.5f constant
- // exponent).
- ComputationDataHandle SqrtF32(const ComputationDataHandle& operand);
-
- // Enqueues a float32 square instruction onto the computation.
- // (float32 is specified as there is an implicit float32 2.0f constant
- // exponent).
- ComputationDataHandle SquareF32(const ComputationDataHandle& operand);
-
- // Enqueues a lhs^rhs computation onto the computation.
- ComputationDataHandle Pow(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues an operator that tests if the operand's values are finite, i.e.,
- // not Inf or NaN. Defined only for floating-point types. Returns an array of
- // booleans with the same shape where entries are true iff the corresponding
- // entry was NaN.
- ComputationDataHandle IsFinite(const ComputationDataHandle& operand);
-
- // Enqueues a convert instruction onto the computation that changes the
- // element type of the operand array to primitive_type.
- ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
- PrimitiveType new_element_type);
-
- // Enqueues a no-op instruction onto the computation that changes
- // the element type of the operand array to primitive_type. The
- // bit-widths of the source and destination element types must be
- // identical.
- ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
- PrimitiveType new_element_type);
-
- // Enqueues a float32 reciprocal instruction onto the computation.
- // (float32 is specified as there is an implicit float32 -1.0f constant
- // exponent).
- //
- // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
- // shape of the operand.
- ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);
-
- // Enqueues a negate instruction onto the computation.
- ComputationDataHandle Neg(const ComputationDataHandle& operand);
-
- // Enqueues a count-leading-zeros instruction onto the computation.
- ComputationDataHandle Clz(const ComputationDataHandle& operand);
-
- // Enqueues a transpose instruction onto the computation.
- ComputationDataHandle Transpose(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
-
- // Enqueues a reverse instruction onto the computation. The order of the
- // elements in the given dimensions is reversed (i.e., the element at index i
- // is moved to index dimension_size - 1 - i).
- ComputationDataHandle Rev(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Enqueues a sort (as increasing order) instruction onto the computation.
- ComputationDataHandle Sort(const ComputationDataHandle& operand);
-
- // Enqueues a clamp instruction onto the computation.
- ComputationDataHandle Clamp(const ComputationDataHandle& min,
- const ComputationDataHandle& operand,
- const ComputationDataHandle& max);
-
- // Enqueues a map instruction onto the computation.
- ComputationDataHandle Map(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});
-
- // Enqueues a N(mu, sigma) random number generation instruction onto the
- // computation.
- ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
- const ComputationDataHandle& sigma,
- const Shape& shape);
-
- // Enqueues a U(a, b) random number generation instruction onto the
- // computation. Returns values in the semi-open interval [a, b).
- ComputationDataHandle RngUniform(const ComputationDataHandle& a,
- const ComputationDataHandle& b,
- const Shape& shape);
-
- // Enqueues a while node onto the computation.
- ComputationDataHandle While(const Computation& condition,
- const Computation& body,
- const ComputationDataHandle& init);
-
- // Enqueues a conditional node onto the computation.
- ComputationDataHandle Conditional(const ComputationDataHandle& predicate,
- const ComputationDataHandle& true_operand,
- const Computation& true_computation,
- const ComputationDataHandle& false_operand,
- const Computation& false_computation);
-
- // Enqueues a ReducePrecision node onto the computation.
- ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand,
- const int exponent_bits,
- const int mantissa_bits);
-
- // Enqueues a Gather node onto the computation.
- ComputationDataHandle Gather(
- const ComputationDataHandle& input,
- const ComputationDataHandle& gather_indices,
- const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
-
- // Enqueues a Send node onto the computation, to send the given operand to
- // a Recv instruction that shares the same channel handle.
- void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);
-
- // Enqueues a Recv node onto the computation. The data comes from a Send
- // instruction that shares the same channel handle and its shape must
- // be the same as the given shape.
- ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);
-
- // Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on parameters with index greater than or equal to
- // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
- // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
- // compile-time constant without evaluating the computation.
- StatusOr<bool> IsConstant(const ComputationDataHandle& operand,
- int64 num_parameters = 0);
-
- // Normalizes operand across spatial and batch dimensions for each feature.
- //
- // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
- // is the normalized result and batch_mean and batch_var are the mean and
- // variance, respectively, across batch for the operand.
- ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand,
- const ComputationDataHandle& scale,
- const ComputationDataHandle& offset,
- float epsilon, int64 feature_index);
-
- // Normalizes operand across spatial and batch dimensions for each feature.
- //
- // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
- // computing `mean` and `variance` for each batch inside the operation. It
- // uses the input `mean` and `variance` instead as estimated values. The
- // purpose of this op is to reduce latency in inference, hence the name
- // `BatchNormInference`.
- //
- // The output has the same shape as `operand`, and contains the normalized
- // values for each batch.
- ComputationDataHandle BatchNormInference(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& offset, const ComputationDataHandle& mean,
- const ComputationDataHandle& variance, float epsilon,
- int64 feature_index);
-
- // Calculates the gradients of a batch norm op.
- //
- // The inputs `batch_mean` and `batch_var` represent the mean and variance
- // across the batch.
- //
- // Returns a tuple of three elements:
- // - grad_operand: Gradient with respect to input `operand`
- // - grad_offset: Gradient with respect to input `offset`
- // - grad_scale: Gradient with respect to input `scale`
- ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand,
- const ComputationDataHandle& scale,
- const ComputationDataHandle& batch_mean,
- const ComputationDataHandle& batch_var,
- const ComputationDataHandle& grad_output,
- float epsilon, int64 feature_index);
-
- // Computes the value of a constant indicated by a
- // ComputationDataHandle using a non-optimized interpreter on the host.
- //
- // The operand must be from the computation currently being built -
- // i.e., returned from this builder with no intervening call to
- // Build(). This happens to currently work regardless of that, but
- // that may stop working at any time.
- //
- // The operand must represent a constant value, which in this case
- // means that it must not statically depend on any parameter of the
- // computation that is being built other then the ones specified on the
- // parameter list. The parameters in the list will be indexed by their
- // parameter id property so the number of parameters specified should be at
- // least as many as the largest used parameter index.
- //
- // `IsConstant` can be used to test whether a computation is a compile-time
- // constant without evaluation it. `ComputeConstant` only succeeds for
- // computations where `IsConstant` returns true.
- //
- // This functionality can be useful when translating a computation
- // into XLA where something that looked dynamic is required by
- // XLA to be specified as a constant. E.g. the source
- // computation (outside of XLA) may include a dynamic
- // computation of the shape of something and ComputeConstant lets
- // you determine what the value of that computation is in the case
- // where the value can be determined at compile time.
- //
- // If output_layout is non-null, then the output of the computation
- // will be stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
- const ComputationDataHandle& operand,
- const Layout* output_layout = nullptr,
- tensorflow::gtl::ArraySlice<Literal> parameters = {});
-
- // Returns a new ComputationBuilder whose resultant Computation is used only
- // by this ComputationBuilder. The sub-ComputationBuilder has the same
- // die_immediately_on_error behavior as the parent.
- std::unique_ptr<ComputationBuilder> CreateSubBuilder(
- const string& computation_name);
-
- // Modifies the computation being built so that executions of it
- // will return the value associated with operand, rather than the
- // last expression enqueued on the ComputationBuilder. Any subsequent
- // operations added to the ComputationBuilder will not have any effect unless
- // SetReturnValue is called again.
- Status SetReturnValue(const ComputationDataHandle& operand);
-
- // Builds the computation with the requested operations, or returns a non-ok
- // status.
- StatusOr<Computation> Build();
-
- // Builds the computation with the requested operations, or notes an error in
- // the parent ComputationBuilder and returns an empty computation if building
- // failed. This function is intended to be used where the returned
- // Computation is only used by the parent ComputationBuilder and hence further
- // operation on the returned Computation will simply be error'ed out if an
- // error occurred while building this computation. If the built computation is
- // to be used by a ComputationBuilder other than the parent ComputationBuilder
- // then Build() should be used instead.
- Computation BuildAndNoteError();
-
- // Returns the first error that was encountered while building the
- // computation. When an error is encountered, by default we return a vacuous
- // ComputationDataHandle and inform the user of the error that occurred while
- // building the computation when they make a final call to Build().
- //
- // See also set_die_immediately_on_error().
- Status first_error() const { return first_error_; }
-
- private:
- // Limited checking of convolution parameters. Returns false on
- // error.
- bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // The parent ComputationBuilder of a sub-ComputationBuilder. The
- // parent_builder_ will be the nullptr if not a sub-ComputationBuilder.
- ComputationBuilder* parent_builder_{nullptr};
-
- // Helper function for creating a Window proto from user-supplied
- // data. Returns true if the user-supplied data was valid.
- bool MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- Window* window);
-
- // Internal helper method that does the building for an arbitrary unary op.
- ComputationDataHandle UnaryOp(UnaryOperation unop,
- const ComputationDataHandle& operand);
-
- // Internal helper method that does the building for an arbitrary binary op.
- // broadcast_dimensions specifies which dimensions to use for broadcasting
- // when the operation is between tensors of different ranks.
- ComputationDataHandle BinaryOp(
- BinaryOperation binop, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
-
- // Internal helper method that does the building for an arbitrary ternary op.
- ComputationDataHandle TernaryOp(TernaryOperation triop,
- const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- const ComputationDataHandle& ehs);
-
- // Internal helper method that does the building for a random number generator
- // of a given distribution with an explicitly specified shape.
- ComputationDataHandle RngOp(
- RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
- const Shape& shape);
-
- // Populates computation_ with a valid object or returns a failing status.
- // This is used before any given operation is enqueued.
- Status PrepareComputation();
-
- // Notes that the error occurred by:
- // * storing it internally and capturing a backtrace if it's the first error
- // (this deferred value will be produced on the call to Build())
- // * dying if die_immediately_on_error_ is true
- void NoteError(const Status& error);
-
- // Helper function that runs the given op_request, filling in op_response.
- // Before the op is run, PrepareComputation is called, and common fields in
- // the op_request are filled in.
- Status RunOp(OpRequest* op_request, OpResponse* op_response);
-
- // Helper function that calls RunOp and calls NoteError on failures.
- void RunOpAndNoteError(OpRequest* op_request);
-
- // Helper function that calls RunOp and either returns the output computation
- // data handle (on success) or a vacuous computation data handle (on failure).
- ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request);
-
- // Helper function that implements GetShape without noting errors. This makes
- // it easier to ensure the real GetShape will note errors on every error path.
- StatusOr<std::unique_ptr<Shape>> GetShapeWithoutNoteError(
- const ComputationDataHandle& operand);
-
- string name_; // Name to use for the built computation.
-
- // The first error encountered while building the computation.
- // This is OK until the first error is encountered.
- Status first_error_;
-
- // The saved stack trace from the point at which the first error occurred.
- tensorflow::SavedStackTrace first_error_backtrace_;
-
- // The computation that operations are enqueued onto.
- Computation computation_;
-
- // The client that the computation is created in. Not owned.
- Client* client_;
-
- // Mode bit that indicates whether to die when a first error is encountered.
- bool die_immediately_on_error_ = false;
-
- // The metadata to attach to each op. This is structured as a "modal"-like
- // operation, in order to simplify client code (and not sprinkle this metadata
- // throughout the TensorFlow op kernel implementations).
- OpMetadata metadata_;
-
- // Sharding for this operator. This is structured as a "model"-like operation,
- // in order to simplify client code, similar to metadata_.
- tensorflow::gtl::optional<OpSharding> sharding_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
-};
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR1(
- tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
- NativeT value) {
- Literal literal(ShapeUtil::MakeShape(
- primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
- literal.PopulateWithValue(value);
- return ConstantLiteral(literal);
-}
-
-inline ComputationDataHandle ComputationBuilder::ConstantR1(
- const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*Literal::CreateR1(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantFromArray(
- const Array<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
- const Array2D<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
- const Array3D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout) {
- return ConstantFromArrayWithLayout(values, layout);
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
- const Array4D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
-// RAII-style object: sets the current sharding assignment in builder on
-// construction, and sets back to the previous assignment on destruction.
-class ScopedShardingAssignment {
- public:
- ScopedShardingAssignment(xla::ComputationBuilder* builder,
- tensorflow::gtl::optional<OpSharding> sharding)
- : builder_(builder), prev_sharding_(builder->sharding()) {
- SetSharding(sharding);
- }
-
- ~ScopedShardingAssignment() { SetSharding(prev_sharding_); }
-
- private:
- void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
- if (sharding.has_value()) {
- builder_->SetSharding(sharding.value());
- } else {
- builder_->ClearSharding();
- }
- }
-
- xla::ComputationBuilder* const builder_;
- tensorflow::gtl::optional<OpSharding> prev_sharding_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
-};
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:reference_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
srcs = ["execution_profile_test.cc"],
deps = [
":client_library_test_base",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
args = ["--xla_hlo_profile"],
deps = [
":client_library_test_base",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
":local_client_test_base",
":test_utils",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include <string>
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include <memory>
#include <string>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include <memory>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include <memory>
#include <vector>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include <string>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include <new>
#include <utility>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include <memory>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"