From 4ec02c23174b07540d190cec620347ee6f31a8d8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 28 Mar 2018 12:18:24 -0700 Subject: [PATCH] [XLA] Redesign: add the rest of the XlaBuilder public methods. PiperOrigin-RevId: 190812260 --- .../compiler/xla/client/xla_client/xla_builder.cc | 107 ++++++++++++++++++++- .../compiler/xla/client/xla_client/xla_builder.h | 71 ++++++++++++++ .../xla/client/xla_client/xla_computation.h | 2 + 3 files changed, 179 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 7d39701..1b94f9a 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -128,6 +128,18 @@ StatusOr XlaBuilder::GetProgramShape() { return GetProgramShape(&root_id); } +XlaComputation XlaBuilder::BuildAndNoteError() { + DCHECK(parent_builder_ != nullptr); + auto build_status = Build(); + if (!build_status.ok()) { + parent_builder_->NoteError( + AddStatus(build_status.status(), + tensorflow::strings::StrCat("error from: ", name_))); + return {}; + } + return build_status.ConsumeValueOrDie(); +} + StatusOr XlaBuilder::Build() { if (!first_error_.ok()) { string backtrace; @@ -945,6 +957,99 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { return UnimplementedOp(); } +StatusOr XlaBuilder::IsConstant(const XlaOp& operand, + int64 num_parameters) { + return Unimplemented("IsConstant is not implemented."); +} + +StatusOr> XlaBuilder::ComputeConstant( + const XlaOp& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { + return Unimplemented("ComputeConstant is not implemented"); +} + +std::unique_ptr XlaBuilder::CreateSubBuilder( + const string& computation_name) { + auto sub_builder = MakeUnique(computation_name); + sub_builder->parent_builder_ = this; + sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; + return sub_builder; +} + +Status XlaBuilder::SetReturnValue(const XlaOp& operand) { + return Unimplemented("SetReturnValue is not implemented."); +} + +/* static */ ConvolutionDimensionNumbers +XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(kConvBatchDimension); + dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_output_batch_dimension(kConvBatchDimension); + dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_kernel_output_feature_dimension( + kConvKernelOutputDimension); + dimension_numbers.set_kernel_input_feature_dimension( + kConvKernelInputDimension); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(i + 2); + dimension_numbers.add_kernel_spatial_dimensions(i + 2); + dimension_numbers.add_output_spatial_dimensions(i + 2); + } + return dimension_numbers; +} + +/* static */ Status XlaBuilder::Validate( + const ConvolutionDimensionNumbers& dnum) { + if (dnum.input_spatial_dimensions_size() < 2) { + return FailedPrecondition("input spacial dimension < 2: %d", + dnum.input_spatial_dimensions_size()); + } + if (dnum.kernel_spatial_dimensions_size() < 2) { + return FailedPrecondition("kernel spacial dimension < 2: %d", + dnum.kernel_spatial_dimensions_size()); + } + if (dnum.output_spatial_dimensions_size() < 2) { + return FailedPrecondition("output spacial dimension < 2: %d", + dnum.output_spatial_dimensions_size()); + } + + if (std::set( + {dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the input are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); + } + if (std::set({dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), + dnum.kernel_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); + } + if (std::set({dnum.output_batch_dimension(), + dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), + dnum.output_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the output are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.output_batch_dimension(), dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); + } + return Status::OK(); +} + StatusOr XlaBuilder::AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice operands) { @@ -986,7 +1091,7 @@ StatusOr XlaBuilder::LookUpInstruction( } XlaOp XlaBuilder::UnimplementedOp() { - NoteError(Unimplemented("Op not yet implemented")); + NoteError(Unimplemented("Op not implemented")); return {}; } diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index c5c3515..f66feb9 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -335,6 +335,26 @@ class XlaBuilder { XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers); + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, @@ -711,10 +731,59 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); + // Computes the value of a constant indicated by a XlaOp using a non-optimized + // interpreter on the host. + // + // The operand must represent a constant value, which in this case + // means that it must not statically depend on any parameter of the + // computation that is being built other then the ones specified on the + // parameter list. The parameters in the list will be indexed by their + // parameter id property so the number of parameters specified should be at + // least as many as the largest used parameter index. + // + // `IsConstant` can be used to test whether a computation is a compile-time + // constant without evaluation it. `ComputeConstant` only succeeds for + // computations where `IsConstant` returns true. + // + // This functionality can be useful when translating a computation + // into XLA where something that looked dynamic is required by + // XLA to be specified as a constant. E.g. the source + // computation (outside of XLA) may include a dynamic + // computation of the shape of something and ComputeConstant lets + // you determine what the value of that computation is in the case + // where the value can be determined at compile time. + // + // If output_layout is non-null, then the output of the computation + // will be stored using that layout. + StatusOr> ComputeConstant( + const XlaOp& operand, const Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Modifies the computation being built so that executions of it will return + // the value associated with operand, rather than the last expression enqueued + // on the XlaBuilder. Any subsequent operations added to the XlaBuilder will + // not have any effect unless SetReturnValue is called again. + Status SetReturnValue(const XlaOp& operand); + // Builds the computation with the requested operations, or returns a non-ok // status. StatusOr Build(); + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + // Returns the first error that was encountered while building the // computation. When an error is encountered, by default we return a vacuous // XlaOp and inform the user of the error that occurred while @@ -814,6 +883,8 @@ class XlaBuilder { // Mode bit that indicates whether to die when a first error is encountered. bool die_immediately_on_error_ = false; + + XlaBuilder* parent_builder_{nullptr}; }; template diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h index 5b89747..78e1e3c 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -29,6 +29,8 @@ namespace xla { // TODO(b/74197823): Replace xla::Computation with this one. class XlaComputation { public: + XlaComputation() : unique_id_(-1) {} + XlaComputation(const XlaComputation&) = delete; XlaComputation& operator=(const XlaComputation&) = delete; -- 2.7.4