return Transfer(*data, shape_with_output_layout);
}
+StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
+ const XlaComputation& computation, const Layout* output_layout) const {
+ ComputeConstantGraphRequest request;
+ *request.mutable_computation() = computation.proto();
+ if (output_layout != nullptr) {
+ *request.mutable_output_layout() = *output_layout;
+ }
+
+ ComputeConstantResponse response;
+
+ VLOG(2) << "making compute-constant-graph request";
+ Status s = stub_->ComputeConstantGraph(&request, &response);
+ VLOG(2) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
+
+ if (!response.has_literal()) {
+ return InternalError(
+ "no computed literal in the provided response in ComputeConstantGraph "
+ "request");
+ }
+ return Literal::CreateFromProto(response.literal());
+}
+
StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
LoadComputationSnapshotRequest request;
*request.mutable_module() = module;
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);
+ // Computes the value of the given computation using a non-optimized
+ // interpreter on the host.
+ //
+ // The computation must not depend on any parameters, or on stateful operators
+ // such as `RngNormal` or `Infeed`.
+ //
+ // This functionality can be useful when translating a computation into XLA
+ // where something that looked dynamic must be specified to XLA as a
+ // constant. E.g. the source computation (outside of XLA) may include a
+ // dynamic computation of the shape of something, and ComputeConstant lets
+ // you determine the value of that computation in the case where it can be
+ // determined at compile time.
+ //
+ // If output_layout is non-null, then the output of the computation will be
+ // stored using that layout.
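+ //
+ // Example usage (a minimal sketch; error handling is elided and `client` is
+ // assumed to be a valid Client*):
+ //
+ //   XlaBuilder b("f");
+ //   XlaOp size = b.Add(b.ConstantR0<int32>(8), b.ConstantR0<int32>(8));
+ //   XlaComputation subgraph =
+ //       b.BuildConstantSubGraph(size).ConsumeValueOrDie();
+ //   std::unique_ptr<Literal> result =
+ //       client->ComputeConstant(subgraph).ConsumeValueOrDie();
+ //   // result holds the int32 scalar 16.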
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ const XlaComputation& computation,
+ const Layout* output_layout = nullptr) const;
+
// Unregister the memory for the given GlobalData on the device.
Status Unregister(const GlobalData& data);
hdrs = ["xla_builder.h"],
deps = [
":xla_computation",
+ "//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
#include <functional>
#include <numeric>
+#include <queue>
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
}
XlaBuilder::XlaBuilder(const string& computation_name)
- : name_(computation_name), unique_id_(GetUniqueId()) {}
+ : name_(computation_name) {}
XlaBuilder::~XlaBuilder() {}
return op.ConsumeValueOrDie();
}
-StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) {
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
TF_RETURN_IF_ERROR(first_error_);
TF_RET_CHECK(root_id != nullptr);
+
ProgramShape program_shape;
// Not all instructions can be roots. Walk backwards from the last added
return program_shape;
}
-StatusOr<ProgramShape> XlaBuilder::GetProgramShape() {
- int64 root_id;
- return GetProgramShape(&root_id);
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
+ int64 root;
+ return GetProgramShape(&root);
+}
+
+void XlaBuilder::IsConstantVisitor(const int64 op_handle,
+ std::set<int64>* visited,
+ bool* is_constant) const {
+ if (visited->count(op_handle) != 0 || !*is_constant) {
+ return;
+ }
+
+ CHECK(op_handle < instructions_.size() && op_handle >= 0);
+
+ const HloInstructionProto& instr = instructions_[op_handle];
+ const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
+ switch (opcode) {
+ default:
+ for (const int64 operand_id : instr.operand_ids()) {
+ IsConstantVisitor(operand_id, visited, is_constant);
+ }
+ // TODO(b/32495713): We aren't checking the called computations.
+ break;
+
+ // Non-functional ops.
+ case HloOpcode::kRng:
+ case HloOpcode::kCrossReplicaSum:
+ // TODO(b/33009255): Implement constant folding for cross replica sum.
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kHostCompute:
+ case HloOpcode::kCall:
+ // TODO(b/32495713): We aren't checking the to_apply computation itself,
+ // so we conservatively say that computations containing the Call op
+ // cannot be constant. We cannot set is_constant=false in other similar
+ // cases since we're already relying on IsConstant to return true.
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kWhile:
+ // TODO(b/32495713): We aren't checking the condition and body
+ // computations themselves.
+ case HloOpcode::kSend:
+ case HloOpcode::kRecv:
+ case HloOpcode::kParameter:
+ *is_constant = false;
+ break;
+ }
+ if (!*is_constant) {
+ VLOG(1) << "Non-constant: " << instr.name();
+ }
+ visited->insert(op_handle);
}
XlaComputation XlaBuilder::BuildAndNoteError() {
}
HloComputationProto entry;
+ entry.set_id(GetUniqueId());  // Give the computation a globally unique id.
+ entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
{
int64 root_id;
- ProgramShape program_shape;
- TF_ASSIGN_OR_RETURN(program_shape, GetProgramShape(&root_id));
- entry.mutable_program_shape()->Swap(&program_shape);
+ TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
+ GetProgramShape(&root_id));
entry.set_root_id(root_id);
}
for (auto& instruction : instructions_) {
+ // Ensures that the instruction names are unique across the whole graph.
+ const string& new_name =
+ StrCat(instruction.name(), ".", entry.id(), ".", instruction.id());
+ instruction.set_name(new_name);
entry.add_instructions()->Swap(&instruction);
}
- entry.set_id(unique_id_);
- entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
XlaComputation computation(entry.id());
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
const string& name) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
+ if (!parameter_numbers_.insert(parameter_number).second) {
return InvalidArgument("parameter %lld already registered",
parameter_number);
}
- parameter_numbers_.insert(parameter_number);
instr.set_parameter_number(parameter_number);
instr.set_name(name);
*instr.mutable_shape() = shape;
});
}
-StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand,
- int64 num_parameters) {
- return Unimplemented("IsConstant is not implemented.");
+StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
+ TF_RETURN_IF_ERROR(first_error_);
+
+ // Verify that the handle is valid.
+ TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
+
+ bool is_constant = true;
+ std::set<int64> visited;
+ IsConstantVisitor(operand.handle(), &visited, &is_constant);
+ return is_constant;
}
-StatusOr<std::unique_ptr<Literal>> XlaBuilder::ComputeConstant(
- const XlaOp& operand, const Layout* output_layout,
- tensorflow::gtl::ArraySlice<Literal> parameters) {
- return Unimplemented("ComputeConstant is not implemented");
+StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
+ const XlaOp& root_op) const {
+ TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
+ if (!is_constant) {
+ auto op_status = LookUpInstruction(root_op);
+ string op_string =
+ op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
+ return InvalidArgument(
+ "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
+ " op requested for constant subgraph: %s\n\n"
+ "This is an internal error that typically happens when the XLA user "
+ "(e.g. TensorFlow) is attempting to determine a value that must be a "
+ "compile-time constant (e.g. an array dimension) but it is not capable "
+ "of being evaluated at XLA compile time.\n\n"
+ "Please file a usability bug with the framework being used (e.g. "
+ "TensorFlow).",
+ op_string.c_str());
+ }
+
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+ LookUpInstruction(root_op));
+ TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+ if (!CanBeRoot(opcode)) {
+ return InvalidArgument("the operand with opcode %s cannot be root",
+ root->opcode().c_str());
+ }
+
+ HloComputationProto entry;
+ entry.set_id(GetUniqueId());  // Give the computation a globally unique id.
+ entry.set_name(StrCat(name_, entry.id(), "_compute_constant"));
+ entry.set_root_id(root->id());
+ ProgramShape* program_shape = entry.mutable_program_shape();
+ *program_shape->mutable_result() = root->shape();
+
+ // We use std::set to keep the instruction ids in ascending order (which is
+ // also a valid dependency order). The related ops will be added to the
+ // subgraph in the same order.
+ std::set<int64> related_ops;
+ tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
+ std::queue<int64> worklist;
+ worklist.push(root->id());
+ related_ops.insert(root->id());
+ while (!worklist.empty()) {
+ int64 node = worklist.front();
+ worklist.pop();
+ for (int64 id : instructions_[node].operand_ids()) {
+ if (related_ops.insert(id).second) {
+ worklist.push(id);
+ }
+ }
+ for (int64 called_id : instructions_[node].called_computation_ids()) {
+ related_calls.insert(called_id);
+ }
+ }
+
+ // Add related ops to the computation.
+ for (int64 id : related_ops) {
+ auto* instr = entry.add_instructions();
+ *instr = instructions_[id];
+ // Ensures that the instruction names are unique across the graph.
+ const string& new_name =
+ StrCat(instr->name(), ".", entry.id(), ".", instr->id());
+ instr->set_name(new_name);
+ }
+
+ XlaComputation computation(entry.id());
+ HloModuleProto* module = computation.mutable_proto();
+ module->set_name(entry.name());
+ module->set_id(entry.id());
+ module->set_entry_computation_name(entry.name());
+ module->set_entry_computation_id(entry.id());
+ *module->mutable_program_shape() = *program_shape;
+ for (auto& e : embedded_) {
+ if (related_calls.find(e.second.id()) != related_calls.end()) {
+ *module->add_computations() = e.second;
+ }
+ }
+ *module->add_computations() = std::move(entry);
+
+ return std::move(computation);
}
std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
return sub_builder;
}
-Status XlaBuilder::SetReturnValue(const XlaOp& operand) {
- return Unimplemented("SetReturnValue is not implemented.");
-}
-
/* static */ ConvolutionDimensionNumbers
XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
ConvolutionDimensionNumbers dimension_numbers;
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
- instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle));
- } else {
- // Append the handle to make sure the name is unique.
- instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle));
+ instr.set_name(instr.opcode());
}
for (const auto& operand : operands) {
if (operand.builder_ == nullptr) {
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
// Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on parameters with index greater than or equal to
- // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
- // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
- // compile-time constant without evaluating the computation.
- StatusOr<bool> IsConstant(const XlaOp& operand, int64 num_parameters = 0);
+ // constant does not depend on any parameters, or on stateful operators such
+ // as `RngNormal` or `Infeed`.
+ //
+ // This tests whether a computation is a compile-time constant without
+ // evaluating the computation.
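+ //
+ // Example (a minimal sketch):
+ //
+ //   XlaBuilder b("f");
+ //   XlaOp c = b.Add(b.ConstantR0<float>(1.0f), b.ConstantR0<float>(2.0f));
+ //   b.IsConstant(c).ValueOrDie();  // true
+ //   XlaOp p = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p");
+ //   b.IsConstant(p).ValueOrDie();  // false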
+ StatusOr<bool> IsConstant(const XlaOp& operand) const;
// Normalizes operand across spatial and batch dimensions for each feature.
//
const XlaOp& grad_output, float epsilon,
int64 feature_index);
- // Computes the value of a constant indicated by a XlaOp using a non-optimized
- // interpreter on the host.
- //
- // The operand must represent a constant value, which in this case
- // means that it must not statically depend on any parameter of the
- // computation that is being built other then the ones specified on the
- // parameter list. The parameters in the list will be indexed by their
- // parameter id property so the number of parameters specified should be at
- // least as many as the largest used parameter index.
- //
- // `IsConstant` can be used to test whether a computation is a compile-time
- // constant without evaluation it. `ComputeConstant` only succeeds for
- // computations where `IsConstant` returns true.
- //
- // This functionality can be useful when translating a computation
- // into XLA where something that looked dynamic is required by
- // XLA to be specified as a constant. E.g. the source
- // computation (outside of XLA) may include a dynamic
- // computation of the shape of something and ComputeConstant lets
- // you determine what the value of that computation is in the case
- // where the value can be determined at compile time.
- //
- // If output_layout is non-null, then the output of the computation
- // will be stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
- const XlaOp& operand, const Layout* output_layout = nullptr,
- tensorflow::gtl::ArraySlice<Literal> parameters = {});
-
// Returns a new XlaBuilder whose resultant Computation is used only by this
// XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
// behavior as the parent.
std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
- // Modifies the computation being built so that executions of it will return
- // the value associated with operand, rather than the last expression enqueued
- // on the XlaBuilder. Any subsequent operations added to the XlaBuilder will
- // not have any effect unless SetReturnValue is called again.
- Status SetReturnValue(const XlaOp& operand);
-
// Builds the computation with the requested operations, or returns a non-ok
- // status.
+ // status. Note that all ops that have been enqueued will be moved to the
+ // computation being returned.
StatusOr<XlaComputation> Build();
// Builds the computation with the requested operations, or notes an error in
// instead.
XlaComputation BuildAndNoteError();
+ // Returns a subgraph rooted at the given operation. If the root is not a
+ // compile-time constant (see `IsConstant`), returns an error.
+ //
+ // The needed ops/computations are copied into the new subgraph.
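+ //
+ // Example (a minimal sketch; the returned subgraph can then be evaluated,
+ // e.g. via Client::ComputeConstant):
+ //
+ //   XlaBuilder b("f");
+ //   XlaOp dim = b.Add(b.ConstantR0<int32>(2), b.ConstantR0<int32>(4));
+ //   StatusOr<XlaComputation> subgraph = b.BuildConstantSubGraph(dim);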
+ StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
+
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous
// XlaOp and inform the user of the error that occurred while
StatusOr<Shape> GetShape(const XlaOp& op) const;
// Returns the (inferred) result for the current computation's shape.
- StatusOr<ProgramShape> GetProgramShape();
+ StatusOr<ProgramShape> GetProgramShape() const;
private:
StatusOr<XlaOp> AddInstruction(
// Returns the (inferred) result for the program shape for the current
// computation and fills the root_id in the pointer.
- StatusOr<ProgramShape> GetProgramShape(int64* root_id);
+ StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
+
+ // A visitor which checks whether an operation is a compile-time constant,
+ // meaning that it doesn't depend on any parameters, or on any stateful
+ // operation such as `RngNormal` or `Infeed`. The visitor walks the
+ // computation starting at a given operation and sets is_constant to false iff
+ // a parameter or stateful operation is encountered.
+ void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
+ bool* is_constant) const;
- string name_; // Name to use for the built computation.
- int64 unique_id_; // The unique id for the built computation.
+ string name_; // Name to use for the built computation.
// The first error encountered while building the computation.
// This is OK until the first error is encountered.
// Since the shape_with_output_layout option in ExecutionOptions has no
// effect on the Evaluator's results, explicitly relayout here.
+ //
+ // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
+ if (arg->has_output_layout()) {
+ result_literal = result_literal->Relayout(arg->output_layout());
+ }
+ *result->mutable_literal() = result_literal->ToProto();
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::ComputeConstantGraph(
+ const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) {
+ if (!arg->has_computation()) {
+ return InvalidArgument("computations may not be empty");
+ }
+ if (!arg->computation().has_program_shape()) {
+ return InvalidArgument("program shape may not be empty");
+ }
+ if (arg->computation().program_shape().parameters_size() != 0) {
+ return InvalidArgument(
+ "constant computation may not depend on any parameters.");
+ }
+
+ ProgramShape program_shape = arg->computation().program_shape();
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
+ if (arg->has_output_layout()) {
+ TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
+ arg->output_layout(), program_shape.result()));
+ }
+
+ HloModuleConfig config(program_shape);
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(arg->computation(), config));
+
+ HloEvaluator evaluator;
+ TF_ASSIGN_OR_RETURN(auto result_literal,
+ evaluator.Evaluate<std::unique_ptr<Literal>>(
+ *module, /*arg_literals=*/{}));
+
+ // Since the result layout has no effect on the Evaluator's results,
+ // explicitly relayout here.
+ //
+ // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
result_literal = result_literal->Relayout(arg->output_layout());
}
// Computes the value of a constant expression.
tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
ComputeConstantResponse* result) override;
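+ // As ComputeConstant, but takes the computation as an HloModuleProto graph.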
+ tensorflow::Status ComputeConstantGraph(
+ const ComputeConstantGraphRequest* arg,
+ ComputeConstantResponse* result) override;
// Returns the shape (with layout) of an array associated with a given data
// handle.
virtual tensorflow::Status ComputeConstant(
const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0;
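+ // Computes the value of a constant expression provided as an
+ // HloModuleProto graph rather than a computation handle.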
+ virtual tensorflow::Status ComputeConstantGraph(
+ const ComputeConstantGraphRequest* arg,
+ ComputeConstantResponse* result) = 0;
+
// Methods used by Computation.
virtual tensorflow::Status SnapshotComputation(
const SnapshotComputationRequest* ag,
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
}
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
- Client* client, const ComputationDataHandle& operand,
- ComputationBuilder* builder, Layout* output_layout = nullptr,
- tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
- TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant(
- operand, output_layout, parameters));
+ Client* client, const XlaOp& operand, XlaBuilder* builder,
+ Layout* output_layout = nullptr) {
+ TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
+ TF_ASSIGN_OR_RETURN(auto computed,
+ client->ComputeConstant(subgraph, output_layout));
return std::move(computed);
}
template <class Scalar>
+ StatusOr<Scalar> ComputeConstantScalar(Client* client, const XlaOp& operand,
+ XlaBuilder* builder) {
+ TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
+ builder, nullptr));
+ return literal->Get<Scalar>({});
+ }
+
+ template <class Scalar>
StatusOr<Scalar> ComputeConstantScalar(
Client* client, const ComputationDataHandle& operand,
ComputationBuilder* builder,
tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
- TF_ASSIGN_OR_RETURN(
- auto literal,
- ComputeConstantLiteral(client, operand, builder, nullptr, parameters));
+ TF_ASSIGN_OR_RETURN(auto literal,
+ builder->ComputeConstant(
+ operand, /*output_layout=*/nullptr, parameters));
return literal->Get<Scalar>({});
}
- bool IsConstant(const ComputationDataHandle& operand,
- ComputationBuilder* builder, int64 num_parameters = 0) {
- StatusOr<bool> result = builder->IsConstant(operand, num_parameters);
+ bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
+ StatusOr<bool> result = builder->IsConstant(operand);
EXPECT_TRUE(result.ok()) << result.status();
return result.ok() ? result.ValueOrDie() : false;
}
TEST_F(ComputeConstantTest, ScalarInt32Literal) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto computation = b.ConstantR0<int32>(42);
EXPECT_TRUE(IsConstant(computation, &b));
TEST_F(ComputeConstantTest, ScalarFloatAdd) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto computation =
b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
EXPECT_TRUE(IsConstant(computation, &b));
TEST_F(ComputeConstantTest, ScalarRng) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto computation =
b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
ShapeUtil::MakeShape(F32, {}));
std::vector<Literal> arguments;
arguments.push_back(std::move(*Literal::CreateR0(42.5f)));
- EXPECT_TRUE(IsConstant(computation, &b, arguments.size()));
-
- auto value =
- ComputeConstantScalar<float>(client, computation, &b, arguments);
- ASSERT_TRUE(value.ok()) << value.status();
- EXPECT_EQ(value.ValueOrDie(), 44.0f);
+ TF_ASSERT_OK_AND_ASSIGN(bool is_constant,
+ b.IsConstant(computation, arguments.size()));
+ EXPECT_TRUE(is_constant);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto value,
+ ComputeConstantScalar<float>(client, computation, &b, arguments));
+ EXPECT_EQ(value, 44.0f);
}
}
TEST_F(ComputeConstantTest, DirectParamMissing) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
EXPECT_FALSE(IsConstant(computation, &b));
TEST_F(ComputeConstantTest, IndirectParamMissing) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto computation =
b.Add(b.ConstantR0<float>(1.0f),
b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
TEST_F(ComputeConstantTest, UnrelatedParam) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
auto constant_4 =
EXPECT_TRUE(IsConstant(constant_13, &b));
- auto value = ComputeConstantScalar<float>(client, constant_13, &b);
- ASSERT_TRUE(value.ok()) << value.status();
- EXPECT_EQ(value.ValueOrDie(), 13.0f);
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto value, ComputeConstantScalar<float>(client, constant_13, &b));
+ EXPECT_EQ(value, 13.0f);
}
}
TEST_F(ComputeConstantTest, NonScalarAdd) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto computation =
b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
EXPECT_TRUE(IsConstant(computation, &b));
- auto computed = ComputeConstantLiteral(client, computation, &b);
- ASSERT_TRUE(computed.ok()) << computed.status();
+ TF_ASSERT_OK_AND_ASSIGN(auto computed,
+ ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
Literal::CreateR1<int32>({4, 6});
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
}
}
TEST_F(ComputeConstantTest, IntegerDivide) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
EXPECT_TRUE(IsConstant(computation, &b));
- auto computed = ComputeConstantLiteral(client, computation, &b);
- ASSERT_TRUE(computed.ok()) << computed.status();
+ TF_ASSERT_OK_AND_ASSIGN(auto computed,
+ ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
}
}
XLA_TEST_F(ComputeConstantTest, Layout) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
- ComputationBuilder b(client, TestName());
+ XlaBuilder b(TestName());
std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
for (const std::vector<int64>& layout : layouts) {
auto layout_proto = LayoutUtil::MakeLayout(layout);
- auto computed = ComputeConstantLiteral(
- client,
- b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
- b.ConstantR2<int32>({{10, 20}, {30, 40}})),
- &b, &layout_proto);
- ASSERT_TRUE(computed.ok()) << computed.status();
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto computed, ComputeConstantLiteral(
+ client,
+ b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+ b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+ &b, &layout_proto));
std::unique_ptr<Literal> expected_literal =
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
LayoutUtil::MakeLayout(layout));
- LiteralTestUtil::AssertEqualShapesAndLayouts(
- expected_literal->shape(), computed.ValueOrDie()->shape());
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+ LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
+ computed->shape());
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
}
}
}
repeated LiteralProto parameters = 4;
}
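+
+// Requests constant evaluation of `computation`, which must not take any
+// parameters (see XlaBuilder::BuildConstantSubGraph). If output_layout is
+// set, the computed literal is stored with that layout.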
+message ComputeConstantGraphRequest {
+ HloModuleProto computation = 1;
+ Layout output_layout = 2;
+}
+
message ComputeConstantResponse {
// A LiteralProto is returned directly for this request, instead of a
// ComputationDataHandle.