[XLA:python] Plumb method to get program shape / return shape from builder.
authorChris Leary <leary@google.com>
Wed, 14 Feb 2018 21:16:30 +0000 (13:16 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 21:20:34 +0000 (13:20 -0800)
Will be useful for setting layouts on compile options for AOT local API.

PiperOrigin-RevId: 185733615

tensorflow/compiler/xla/client/computation_builder.cc
tensorflow/compiler/xla/client/computation_builder.h
tensorflow/compiler/xla/python/local_computation_builder.cc
tensorflow/compiler/xla/python/local_computation_builder.h
tensorflow/compiler/xla/python/local_computation_builder.i
tensorflow/compiler/xla/python/xla_client.py
tensorflow/compiler/xla/python/xla_client_test.py
tensorflow/compiler/xla/tests/axpy_simple_test.cc

index 46f2ed4..b1dcad6 100644 (file)
@@ -233,6 +233,26 @@ StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
   return status_or_shape;
 }
 
+StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
+  TF_RETURN_IF_ERROR(first_error_);
+
+  GetComputationShapeRequest request;
+  *request.mutable_computation() = computation_.handle();
+  GetComputationShapeResponse response;
+
+  VLOG(2) << "making get-program-shape-request";
+  Status status = client_->stub()->GetComputationShape(&request, &response);
+  VLOG(2) << "done with get-program-shape-request";
+
+  if (!status.ok()) {
+    first_error_ = status;
+    return status;
+  }
+
+  TF_RET_CHECK(response.has_program_shape());
+  return std::move(*response.mutable_program_shape());
+}
+
 ComputationDataHandle ComputationBuilder::CheckShape(
     const ComputationDataHandle& operand, const Shape& expected_shape) {
   std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie();
index ea4cdb7..7cae91e 100644 (file)
@@ -101,6 +101,9 @@ class ComputationBuilder {
   StatusOr<std::unique_ptr<Shape>> GetShape(
       const ComputationDataHandle& operand);
 
+  // Retrieves the (inferred) program shape of the current computation: its
+  // parameter shapes and result shape.
+  StatusOr<ProgramShape> GetProgramShape();
+
   // Checks that the operand has the given expected shape. Returns the operand
   // if yes, fails with a CHECK error if no.
   ComputationDataHandle CheckShape(const ComputationDataHandle& operand,
index 3b0d837..a89146d 100644 (file)
@@ -303,6 +303,11 @@ std::unique_ptr<Shape> LocalComputationBuilder::GetShape(
   return builder_.GetShape(operand).ConsumeValueOrDie();
 }
 
+StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
+  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
+  return program_shape.result();
+}
+
 ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) {
   return builder_.Infeed(shape);
 }
index 4c6a504..d682204 100644 (file)
@@ -133,6 +133,9 @@ class LocalComputationBuilder {
 
   std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand);
 
+  // Returns the shape of the current return value for the computation.
+  StatusOr<Shape> GetReturnValueShape();
+
   ComputationDataHandle Infeed(const Shape& shape);
 
   void Outfeed(const ComputationDataHandle& operand, const Shape& shape,
index 114754b..fa6c8bf 100644 (file)
@@ -201,6 +201,15 @@ tensorflow::ImportNumpy();
   }
 }
 
+%typemap(out) StatusOr<Shape> {
+  if ($1.ok()) {
+    $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
+  } else {
+    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+    return NULL;
+  }
+}
+
 %typemap(out) Status {
   if (!$1.ok()) {
     PyErr_SetString(
@@ -850,6 +859,7 @@ tensorflow::ImportNumpy();
 %unignore xla::swig::LocalComputationBuilder::ClearOpMetadata;
 %unignore xla::swig::LocalComputationBuilder::Parameter;
 %unignore xla::swig::LocalComputationBuilder::GetShape;
+%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape;
 %unignore xla::swig::LocalComputationBuilder::Infeed;
 %unignore xla::swig::LocalComputationBuilder::Outfeed;
 %unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
index b3b2b1c..2489ad9 100644 (file)
@@ -195,6 +195,12 @@ class Shape(object):
     self._minor_to_major = minor_to_major
     self._check_minor_to_major()
 
+  def __eq__(self, other):
+    # pylint: disable=protected-access
+    return (self.np_dtype == other.np_dtype and
+            self._dimensions == other._dimensions and
+            self._minor_to_major == other._minor_to_major)
+
   def __repr__(self):
     return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
             'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
@@ -606,6 +612,9 @@ class ComputationBuilder(object):
   def GetShape(self, operand):
     return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand)))
 
+  def GetReturnValueShape(self):
+    return _wrap_shape(self._client.GetReturnValueShape())
+
   def GetComputationStats(self):
     raise NotImplementedError()
 
index 7565e81..c9d09cd 100644 (file)
@@ -86,7 +86,8 @@ class ComputationsWithConstantsTest(LocalComputationTest):
 
   def testConstantScalarSumF32(self):
     c = self._NewComputation()
-    c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
+    root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
+    self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
     self._ExecuteAndCompareClose(c, expected=4.25)
 
   def testConstantScalarSumF64(self):
index 627a9c3..3f6fd7c 100644 (file)
@@ -62,6 +62,10 @@ TEST_F(AxpySimpleTest, AxpyTenValues) {
   auto ax = builder.Mul(alpha, x);
   auto axpy = builder.Add(ax, y);
 
+  TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape());
+
+  EXPECT_EQ("() -> f32[10]", ShapeUtil::HumanString(shape));
+
   std::vector<float> expected = {
       1.85840735, -1.85840735, 2.28318531,   -2.28318531,  -6.42477796,
       6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};