HloSharding parsing from string, used by new Sharding HloMatcher for ease of use.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 May 2018 01:38:34 +0000 (18:38 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 01:41:13 +0000 (18:41 -0700)
PiperOrigin-RevId: 197825588

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/hlo_matchers.h
tensorflow/compiler/xla/service/hlo_matchers_test.cc
tensorflow/compiler/xla/service/hlo_sharding_test.cc
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser.h

index d172264..749873e 100644 (file)
@@ -376,6 +376,7 @@ cc_library(
         ":hlo",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
         "//tensorflow/core:lib",
     ],
 )
@@ -387,7 +388,6 @@ tf_cc_test(
         ":hlo_matchers",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
-        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
     ],
 )
 
@@ -431,6 +431,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
     ],
 )
 
index c33bdad..dfefad3 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 #include "tensorflow/core/lib/gtl/optional.h"
 
 namespace xla {
@@ -324,6 +325,12 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
   return ::testing::MakeMatcher(
       new ::xla::testing::HloShardingMatcher(sharding));
 }
+// Matcher for Sharding from sharding string
+inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
+    tensorflow::StringPiece sharding) {
+  return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
+      xla::tools::ParseSharding(sharding).ValueOrDie()));
+}
 // Verifies that no HloSharding is set for an HLO instruction.
 inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
   return ::testing::MakeMatcher(
index 016cc01..1d10e3c 100644 (file)
@@ -15,7 +15,6 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 
 namespace op = xla::testing::opcode_matchers;
 using ::testing::_;
@@ -147,6 +146,18 @@ TEST(HloMatchersTest, ShardingMatcher) {
                                             "param.1");
   p1->set_sharding(HloSharding::AssignDevice(1));
 
+  auto tuple_shape = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}),
+       ShapeUtil::MakeShape(F32, {11})});
+  auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2");
+  Array<int64> assignment({2});
+  assignment.SetValues({0, 1});
+  auto sharding = HloSharding::Tuple(
+      tuple_shape,
+      {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment),
+       HloSharding::AssignDevice(1), HloSharding::Replicate()});
+  p2->set_sharding(sharding);
+
   EXPECT_THAT(p0.get(), op::NoSharding());
   EXPECT_THAT(p0.get(),
               ::testing::Not(op::Sharding(HloSharding::AssignDevice(1))));
@@ -155,6 +166,11 @@ TEST(HloMatchersTest, ShardingMatcher) {
               ::testing::Not(op::Sharding(HloSharding::AssignDevice(0))));
   EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1)));
 
+  EXPECT_THAT(
+      p2.get(),
+      op::Sharding(
+          "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}"));
+
   EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
               "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
               "{maximal device=1})");
index 3bf0d25..94d1a32 100644 (file)
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/hlo_sharding.h"
-
 #include <set>
 #include <unordered_map>
 #include <utility>
@@ -25,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 #include "tensorflow/compiler/xla/util.h"
 
 namespace xla {
@@ -312,5 +311,48 @@ TEST_F(HloShardingTest, OstreamTest) {
   EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
 }
 
+TEST_F(HloShardingTest, Parse) {
+  auto check = [](const HloSharding& sharding) {
+    TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding,
+                            tools::ParseSharding(sharding.ToString()));
+    EXPECT_EQ(sharding, parsed_sharding);
+  };
+  check(HloSharding::Replicate());
+  check(HloSharding::AssignDevice(2));
+  check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
+                          Array4D<int64>({{{{0}, {1}}}})));
+  // Empty tuple.
+  check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), {}));
+  {
+    // Non-nested tuple.
+    auto tuple_shape =
+        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
+                                   ShapeUtil::MakeShape(F32, {3, 5, 7}),
+                                   ShapeUtil::MakeShape(F32, {3, 7})});
+    check(HloSharding::Tuple(
+        tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
+                                        Array4D<int64>({{{{0}, {1}}}})),
+                      HloSharding::Replicate(), HloSharding::AssignDevice(1)}));
+  }
+  {
+    // Nested tuple.
+    auto tuple_shape = ShapeUtil::MakeTupleShape(
+        {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
+         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
+                                    ShapeUtil::MakeShape(F32, {3, 7})})});
+    std::vector<HloSharding> leaf_shardings = {
+        HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
+                          Array4D<int64>({{{{0}, {1}}}})),
+        HloSharding::Replicate(), HloSharding::AssignDevice(1)};
+    ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
+    // Assign leaf_shardings to sharding_tree leaves.
+    auto it = leaf_shardings.begin();
+    for (auto& index_to_sharding : sharding_tree.leaves()) {
+      index_to_sharding.second = *it++;
+    }
+    check(HloSharding::Tuple(sharding_tree));
+  }
+}
+
 }  // namespace
 }  // namespace xla
index d0e7af8..e990b6a 100644 (file)
@@ -56,6 +56,11 @@ class HloParser {
   // Returns the error information.
   string GetError() const { return Join(error_, "\n"); }
 
+  // Stand alone parsing for sharding. The parser string is supposed to
+  // contain the body of the sharding, i.e. just the rhs of the "sharding={...}"
+  // attribute string.
+  StatusOr<HloSharding> ParseShardingOnly();
+
  private:
   // ParseXXX returns false if an error occurred.
   bool ParseHloModule();
@@ -2673,6 +2678,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation,
   return true;
 }
 
+StatusOr<HloSharding> HloParser::ParseShardingOnly() {
+  lexer_.Lex();
+  OpSharding op_sharding;
+  if (!ParseSharding(&op_sharding)) {
+    return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+  }
+  if (lexer_.GetKind() != TokKind::kEof) {
+    return InvalidArgument("Syntax error:\nExtra content after sharding");
+  }
+  return HloSharding::FromProto(op_sharding);
+}
+
 }  // namespace
 
 StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
@@ -2689,5 +2706,11 @@ StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
   return Parse(str, config);
 }
 
+StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
+  HloModuleConfig config;
+  HloParser parser(str, config);
+  return parser.ParseShardingOnly();
+}
+
 }  // namespace tools
 }  // namespace xla
index 2f97a2b..f7854f4 100644 (file)
@@ -36,6 +36,10 @@ StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str,
 // format, parses the string and creates a HloModule with default config.
 StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
 
+// Parse sharding from str. str is supposed to contain the body of the
+// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
+StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+
 }  // namespace tools
 }  // namespace xla