<< "Do not add XlaOp from builder " << operand.builder_->name()
<< " to builder " << this->name();
instr.add_operand_ids(operand.handle());
- // TODO(b/74197823): Set metadata and sharding.
}
+
+ *instr.mutable_metadata() = metadata_;
+ if (sharding_) {
+ *instr.mutable_sharding() = *sharding_;
+ }
+
instructions_.push_back(instr);
XlaOp op(handle, this);
// Returns the computation name.
const string& name() const { return name_; }
+ // Sets OpMetadata that will be added to all instructions until cleared.
+ //
+ // OpMetadata is often applied to a series of XLA HLO instructions. As a
+ // result, OpMetadata is set on the Computation Builder. All subsequent
+ // instructions generated via this Computation Builder will have the same
+ // OpMetadata attached until a call to ClearOpMetadata.
+ void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
+
+ // Clears the HloMetadata state.
+ void ClearOpMetadata() { metadata_.Clear(); }
+
+ // Sets an OpSharding that will be attached to all instructions until cleared.
+ void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
+
+ // Clears the sharding. Ops will be sharded according to the default placement
+ // policy.
+ void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
+
+ // Returns the OpSharding that will be attached to all instructions.
+ const tensorflow::gtl::optional<OpSharding>& sharding() const {
+ return sharding_;
+ }
+
// Sets the builder to a mode where it will die immediately when an error is
// encountered, rather than producing it in a deferred fashion when Build() is
// called (which is the default).
// The unique parameter numbers.
tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+ // The metadata to attach to each op. This is structured as a "modal"-like
+ // operation, in order to simplify client code (and not sprinkle this metadata
+ // throughout the TensorFlow op kernel implementations).
+ OpMetadata metadata_;
+
+ // Sharding for this operator. This is structured as a "model"-like operation,
+ // in order to simplify client code, similar to metadata_.
+ tensorflow::gtl::optional<OpSharding> sharding_;
+
// Mode bit that indicates whether to die when a first error is encountered.
bool die_immediately_on_error_ = false;
};
repeated int64 operand_ids = 36;
repeated int64 control_predecessor_ids = 37;
repeated int64 called_computation_ids = 38;
+
+ xla.OpSharding sharding = 40;
}
// Serialization of HloComputation.
deps = [
":local_client_test_base",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/service:computation_tracker",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/core:test_main",
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
metadata_.set_op_name("my_sum_op");
}
- void BuildAddComputation(ComputationBuilder* builder) {
+ void BuildAddComputation(XlaBuilder* builder) {
auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder->Add(x, y);
};
TEST_F(HloMetadataTest, MetadataPropagation) {
- ComputationBuilder builder(local_client_, "add");
+ XlaBuilder builder("add");
builder.SetOpMetadata(metadata_);
BuildAddComputation(&builder);
builder.ClearOpMetadata();
}
TEST_F(HloMetadataTest, MetadataClearing) {
- ComputationBuilder builder(local_client_, "add");
+ XlaBuilder builder("add");
builder.SetOpMetadata(metadata_);
// Some other pretend computation here.
builder.ClearOpMetadata();