return status_or_shape;
}
+// Queries the XLA service for the program shape (parameter shapes plus
+// result shape) of the computation currently under construction.
+StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
+ // If a previous builder operation already failed, short-circuit with that
+ // error rather than issuing a doomed RPC.
+ TF_RETURN_IF_ERROR(first_error_);
+
+ // Identify the computation to the service by its handle.
+ GetComputationShapeRequest request;
+ *request.mutable_computation() = computation_.handle();
+ GetComputationShapeResponse response;
+
+ VLOG(2) << "making get-program-shape-request";
+ Status status = client_->stub()->GetComputationShape(&request, &response);
+ VLOG(2) << "done with get-program-shape-request";
+
+ if (!status.ok()) {
+ // Latch the failure so all subsequent builder calls fail fast too.
+ first_error_ = status;
+ return status;
+ }
+
+ TF_RET_CHECK(response.has_program_shape());
+ // Move the program shape out of the response to avoid copying it.
+ return std::move(*response.mutable_program_shape());
+}
+
ComputationDataHandle ComputationBuilder::CheckShape(
const ComputationDataHandle& operand, const Shape& expected_shape) {
std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie();
StatusOr<std::unique_ptr<Shape>> GetShape(
const ComputationDataHandle& operand);
+ // Retrieves the (inferred) shape of the current computation's result.
+ StatusOr<ProgramShape> GetProgramShape();
+
// Checks that the operand has the given expected shape. Returns the operand
// if yes, fails with a CHECK error if no.
ComputationDataHandle CheckShape(const ComputationDataHandle& operand,
return builder_.GetShape(operand).ConsumeValueOrDie();
}
+// Returns the shape of the computation's return value, i.e. the result
+// component of the full program shape reported by the underlying builder.
+StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
+ TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
+ return program_shape.result();
+}
+
ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) {
return builder_.Infeed(shape);
}
std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand);
+ // Returns the shape of the current return value for the computation.
+ StatusOr<Shape> GetReturnValueShape();
+
ComputationDataHandle Infeed(const Shape& shape);
void Outfeed(const ComputationDataHandle& operand, const Shape& shape,
}
}
+// SWIG out-typemap for StatusOr<Shape>: on success, convert the contained
+// Shape to its Python representation (via numpy::PyShapeInfoFromXlaShape);
+// on failure, raise RuntimeError carrying the status message.
+%typemap(out) StatusOr<Shape> {
+ if ($1.ok()) {
+ $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+ return NULL;
+ }
+}
+
%typemap(out) Status {
if (!$1.ok()) {
PyErr_SetString(
%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata;
%unignore xla::swig::LocalComputationBuilder::Parameter;
%unignore xla::swig::LocalComputationBuilder::GetShape;
+%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape;
%unignore xla::swig::LocalComputationBuilder::Infeed;
%unignore xla::swig::LocalComputationBuilder::Outfeed;
%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
self._minor_to_major = minor_to_major
self._check_minor_to_major()
+ def __eq__(self, other):
+ # pylint: disable=protected-access
+ return (self.np_dtype == other.np_dtype and
+ self._dimensions == other._dimensions and
+ self._minor_to_major == other._minor_to_major)
+
def __repr__(self):
return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
def GetShape(self, operand):
return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand)))
+ def GetReturnValueShape(self):
+   """Returns the wrapped Shape of this computation's return (root) value."""
+   return _wrap_shape(self._client.GetReturnValueShape())
+
def GetComputationStats(self):
raise NotImplementedError()
def testConstantScalarSumF32(self):
c = self._NewComputation()
- c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
+ root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
+ # The root op's shape must agree with the computation's return-value shape.
+ self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
self._ExecuteAndCompareClose(c, expected=4.25)
def testConstantScalarSumF64(self):
auto ax = builder.Mul(alpha, x);
auto axpy = builder.Add(ax, y);
+ TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape());
+
+ EXPECT_EQ("() -> f32[10]", ShapeUtil::HumanString(shape));
+
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};