[XLA] Redesign: handle metadata and sharding.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Mar 2018 00:39:51 +0000 (17:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 00:42:55 +0000 (17:42 -0700)
- Add a xla.OpSharding field to the HloInstructionProto.
- Metatdata handling is tested.

PiperOrigin-RevId: 190553731

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/client/xla_client/xla_builder.h
tensorflow/compiler/xla/service/hlo.proto
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/hlo_metadata_test.cc

index bf91efc..1b90b45 100644 (file)
@@ -896,8 +896,13 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
         << "Do not add XlaOp from builder " << operand.builder_->name()
         << " to builder " << this->name();
     instr.add_operand_ids(operand.handle());
-    // TODO(b/74197823): Set metadata and sharding.
   }
+
+  *instr.mutable_metadata() = metadata_;
+  if (sharding_) {
+    *instr.mutable_sharding() = *sharding_;
+  }
+
   instructions_.push_back(instr);
 
   XlaOp op(handle, this);
index 22cf094..cc33356 100644 (file)
@@ -85,6 +85,29 @@ class XlaBuilder {
   // Returns the computation name.
   const string& name() const { return name_; }
 
+  // Sets OpMetadata that will be added to all instructions until cleared.
+  //
+  // OpMetadata is often applied to a series of XLA HLO instructions. As a
+  // result, OpMetadata is set on the Computation Builder. All subsequent
+  // instructions generated via this Computation Builder will have the same
+  // OpMetadata attached until a call to ClearOpMetadata.
+  void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
+
+  // Clears the HloMetadata state.
+  void ClearOpMetadata() { metadata_.Clear(); }
+
+  // Sets an OpSharding that will be attached to all instructions until cleared.
+  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
+
+  // Clears the sharding. Ops will be sharded according to the default placement
+  // policy.
+  void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
+
+  // Returns the OpSharding that will be attached to all instructions.
+  const tensorflow::gtl::optional<OpSharding>& sharding() const {
+    return sharding_;
+  }
+
   // Sets the builder to a mode where it will die immediately when an error is
   // encountered, rather than producing it in a deferred fashion when Build() is
   // called (which is the default).
@@ -776,6 +799,15 @@ class XlaBuilder {
   // The unique parameter numbers.
   tensorflow::gtl::FlatSet<int64> parameter_numbers_;
 
+  // The metadata to attach to each op. This is structured as a "modal"-like
+  // operation, in order to simplify client code (and not sprinkle this metadata
+  // throughout the TensorFlow op kernel implementations).
+  OpMetadata metadata_;
+
+  // Sharding for this operator. This is structured as a "model"-like operation,
+  // in order to simplify client code, similar to metadata_.
+  tensorflow::gtl::optional<OpSharding> sharding_;
+
   // Mode bit that indicates whether to die when a first error is encountered.
   bool die_immediately_on_error_ = false;
 };
index 406fead..0b446c6 100644 (file)
@@ -141,6 +141,8 @@ message HloInstructionProto {
   repeated int64 operand_ids = 36;
   repeated int64 control_predecessor_ids = 37;
   repeated int64 called_computation_ids = 38;
+
+  xla.OpSharding sharding = 40;
 }
 
 // Serialization of HloComputation.
index 3705d6c..5ab25f2 100644 (file)
@@ -1810,9 +1810,8 @@ tf_cc_test(
     deps = [
         ":local_client_test_base",
         "//tensorflow/compiler/xla:test_helpers",
-        "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:local_client",
-        "//tensorflow/compiler/xla/service:computation_tracker",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
         "//tensorflow/compiler/xla/service:cpu_plugin",
         "//tensorflow/compiler/xla/service:local_service",
         "//tensorflow/core:test_main",
index eded207..cf971dd 100644 (file)
@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/local_service.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
@@ -30,7 +29,7 @@ class HloMetadataTest : public LocalClientTestBase {
     metadata_.set_op_name("my_sum_op");
   }
 
-  void BuildAddComputation(ComputationBuilder* builder) {
+  void BuildAddComputation(XlaBuilder* builder) {
     auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
     builder->Add(x, y);
@@ -40,7 +39,7 @@ class HloMetadataTest : public LocalClientTestBase {
 };
 
 TEST_F(HloMetadataTest, MetadataPropagation) {
-  ComputationBuilder builder(local_client_, "add");
+  XlaBuilder builder("add");
   builder.SetOpMetadata(metadata_);
   BuildAddComputation(&builder);
   builder.ClearOpMetadata();
@@ -61,7 +60,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) {
 }
 
 TEST_F(HloMetadataTest, MetadataClearing) {
-  ComputationBuilder builder(local_client_, "add");
+  XlaBuilder builder("add");
   builder.SetOpMetadata(metadata_);
   // Some other pretend computation here.
   builder.ClearOpMetadata();