From 98c717b782ba402af9c9651dd5fdeb418e6a8e16 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Wed, 14 Feb 2018 13:16:30 -0800 Subject: [PATCH] [XLA:python] Plumb method to get program shape / return shape from builder. Will be useful for setting layouts on compile options for AOT local API. PiperOrigin-RevId: 185733615 --- .../compiler/xla/client/computation_builder.cc | 20 ++++++++++++++++++++ tensorflow/compiler/xla/client/computation_builder.h | 3 +++ .../compiler/xla/python/local_computation_builder.cc | 5 +++++ .../compiler/xla/python/local_computation_builder.h | 3 +++ .../compiler/xla/python/local_computation_builder.i | 10 ++++++++++ tensorflow/compiler/xla/python/xla_client.py | 9 +++++++++ tensorflow/compiler/xla/python/xla_client_test.py | 3 ++- tensorflow/compiler/xla/tests/axpy_simple_test.cc | 4 ++++ 8 files changed, 56 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 46f2ed4..b1dcad6 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -233,6 +233,26 @@ StatusOr> ComputationBuilder::GetShape( return status_or_shape; } +StatusOr ComputationBuilder::GetProgramShape() { + TF_RETURN_IF_ERROR(first_error_); + + GetComputationShapeRequest request; + *request.mutable_computation() = computation_.handle(); + GetComputationShapeResponse response; + + VLOG(2) << "making get-program-shape-request"; + Status status = client_->stub()->GetComputationShape(&request, &response); + VLOG(2) << "done with get-program-shape-request"; + + if (!status.ok()) { + first_error_ = status; + return status; + } + + TF_RET_CHECK(response.has_program_shape()); + return std::move(*response.mutable_program_shape()); +} + ComputationDataHandle ComputationBuilder::CheckShape( const ComputationDataHandle& operand, const Shape& expected_shape) { std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index ea4cdb7..7cae91e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -101,6 +101,9 @@ class ComputationBuilder { StatusOr> GetShape( const ComputationDataHandle& operand); + // Retrieves the (inferred) result for the current computation's shape. + StatusOr GetProgramShape(); + // Checks that the operand has the given expected shape. Returns the operand // if yes, fails with a CHECK error if no. ComputationDataHandle CheckShape(const ComputationDataHandle& operand, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 3b0d837..a89146d 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -303,6 +303,11 @@ std::unique_ptr LocalComputationBuilder::GetShape( return builder_.GetShape(operand).ConsumeValueOrDie(); } +StatusOr LocalComputationBuilder::GetReturnValueShape() { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); + return program_shape.result(); +} + ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 4c6a504..d682204 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -133,6 +133,9 @@ class LocalComputationBuilder { std::unique_ptr GetShape(const ComputationDataHandle& operand); + // Returns the shape of the current return value for the computation. + StatusOr GetReturnValueShape(); + ComputationDataHandle Infeed(const Shape& shape); void Outfeed(const ComputationDataHandle& operand, const Shape& shape, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 114754b..fa6c8bf 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -201,6 +201,15 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + %typemap(out) Status { if (!$1.ok()) { PyErr_SetString( @@ -850,6 +859,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; %unignore xla::swig::LocalComputationBuilder::Parameter; %unignore xla::swig::LocalComputationBuilder::GetShape; +%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; %unignore xla::swig::LocalComputationBuilder::Infeed; %unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index b3b2b1c..2489ad9 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -195,6 +195,12 @@ class Shape(object): self._minor_to_major = minor_to_major self._check_minor_to_major() + def __eq__(self, other): + # pylint: disable=protected-access + return (self.np_dtype == other.np_dtype and + self._dimensions == other._dimensions and + self._minor_to_major == other._minor_to_major) + def __repr__(self): return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, @@ -606,6 +612,9 @@ class ComputationBuilder(object): def GetShape(self, operand): return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + def GetReturnValueShape(self): + return _wrap_shape(self._client.GetReturnValueShape()) + def GetComputationStats(self): raise NotImplementedError() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 7565e81..c9d09cd 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -86,7 +86,8 @@ class ComputationsWithConstantsTest(LocalComputationTest): def testConstantScalarSumF32(self): c = self._NewComputation() - c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) self._ExecuteAndCompareClose(c, expected=4.25) def testConstantScalarSumF64(self): diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 627a9c3..3f6fd7c 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -62,6 +62,10 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto ax = builder.Mul(alpha, x); auto axpy = builder.Add(ax, y); + TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); + + EXPECT_EQ("() -> f32[10]", ShapeUtil::HumanString(shape)); + std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; -- 2.7.4