Add support for parsing the "gather" HLO
authorSanjoy Das <sanjoy@google.com>
Mon, 26 Feb 2018 19:12:04 +0000 (11:12 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187050345

tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

index cd2b843..e60a5a4 100644 (file)
@@ -1049,9 +1049,40 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
           HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
       break;
     }
-    case HloOpcode::kGather:
-      // TODO(b/72710576): HLO parsing is not implemented for Gather.
-      return TokenError("HLO parsing is not implemented for Gather");
+    case HloOpcode::kGather: {
+      optional<std::vector<int64>> output_window_dims;
+      attrs["output_window_dims"] = {
+          /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
+      optional<std::vector<int64>> elided_window_dims;
+      attrs["elided_window_dims"] = {
+          /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
+      optional<std::vector<int64>> gather_dims_to_operand_dims;
+      attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
+                                              AttrTy::kBracedInt64List,
+                                              &gather_dims_to_operand_dims};
+      optional<int64> index_vector_dim;
+      attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
+                                   &index_vector_dim};
+      optional<std::vector<int64>> window_bounds;
+      attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
+                                &window_bounds};
+
+      if (!ParseOperands(&operands, /*expected_size=*/2) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+
+      GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers(
+          /*output_window_dims=*/*output_window_dims,
+          /*elided_window_dims=*/*elided_window_dims,
+          /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+          /*index_vector_dim=*/*index_vector_dim);
+
+      instruction = builder->AddInstruction(HloInstruction::CreateGather(
+          shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
+          dim_numbers, *window_bounds));
+      break;
+    }
     case HloOpcode::kTrace:
       return TokenError(StrCat("parsing not yet implemented for op: ",
                                HloOpcodeString(opcode)));
index b8c6b59..863081d 100644 (file)
@@ -718,6 +718,18 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
 
 )"
 },
+{
+"gather",
+R"(HloModule StringifyGather
+
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
+  %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+  %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+  ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+}
+
+)"
+},
   });
   // clang-format on
 }
@@ -862,6 +874,18 @@ ENTRY dot {
 
 )"
 },
+{
+"gather",
+R"(HloModule gather
+
+ENTRY Gather {
+  input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+  gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+  ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+}
+
+)"
+},
   });
   // clang-format on
 }