[XLA] Add support for CustomCall in HLO parser.
authorJustin Lebar <jlebar@google.com>
Thu, 14 Dec 2017 01:45:55 +0000 (17:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 14 Dec 2017 01:49:39 +0000 (17:49 -0800)
PiperOrigin-RevId: 178984357

tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

index 1dab6076a52b7fb146575e9a012184e4a42137bc..220d5044a29a8ab724cf56394a9fbf7c6e4010e4 100644 (file)
@@ -2082,6 +2082,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
     extra.push_back(StrCat("exponent_bits=", exponent_bits_));
     extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
   }
+  if (opcode() == HloOpcode::kCustomCall) {
+    extra.push_back(
+        StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+  }
   return extra;
 }
 
index 192f134cb9454fee1bf6477d7ffe577eb75de7cd..4f67ed23801f9b8eb50b7c959f0796ca4e6c578d 100644 (file)
@@ -901,7 +901,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
           /*false_computation_arg=*/operands[2], *false_computation));
       break;
     }
-    case HloOpcode::kCustomCall:
+    case HloOpcode::kCustomCall: {
+      optional<string> custom_call_target;
+      attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
+                                     &custom_call_target};
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+          shape, operands, *custom_call_target));
+      break;
+    }
     case HloOpcode::kTrace:
       return TokenError(StrCat("parsing not yet implemented for op: ",
                                HloOpcodeString(opcode)));
index 3b1f81134b0a10383311a3ce1c5984679349b6f1..61d8902855f47a11716f8a60b082c6c25ea9b8af 100644 (file)
@@ -728,7 +728,20 @@ ENTRY %Parameters1.v4 () -> f32[] {
 }
 
 )"
+},
+
+// CustomCall
+{
+"CustomCall",
+R"(HloModule custom_call:
+
+ENTRY %CustomCall () -> f32[1,2,3] {
+  %constant = f32[1]{0} constant({12345})
+  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
 }
+
+)"
+},
   });
   // clang-format on
 }