[XLA::Interpreter] Add support for kCall to HloEvaluator. Also enable
authorKay Zhu <kayzhu@google.com>
Mon, 26 Feb 2018 23:37:27 +0000 (15:37 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
xla/tests/call_test to run on interpreter.

PiperOrigin-RevId: 187092587

tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_evaluator.h
tensorflow/compiler/xla/tests/BUILD

index 15ae531..fd06b19 100644 (file)
@@ -2445,6 +2445,26 @@ Status HloEvaluator::HandleCopy(HloInstruction* copy) {
   return Status::OK();
 }
 
+Status HloEvaluator::HandleCall(HloInstruction* call) {
+  auto* computation = call->to_apply();
+  auto operands = call->operands();
+
+  std::vector<const Literal*> arg_literals;
+  arg_literals.reserve(operands.size());
+  for (auto operand : operands) {
+    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
+    arg_literals.push_back(&arg_literal);
+  }
+
+  HloEvaluator embedded_evaluator;
+  std::unique_ptr<Literal> result =
+      embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
+          .ConsumeValueOrDie();
+
+  evaluated_[call] = std::move(result);
+  return Status::OK();
+}
+
 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
   VLOG(2) << "About to visit HLO: " << hlo->ToString();
   return Status::OK();
index 3b2b697..c65d991 100644 (file)
@@ -153,6 +153,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 
   Status HandleCopy(HloInstruction* copy) override;
 
+  Status HandleCall(HloInstruction* call) override;
+
  private:
   // Returns the already-evaluated literal result for the instruction.
   // A Constant instruction is considered evaluated and its literal will be
index 97abf21..33fde97 100644 (file)
@@ -1143,6 +1143,9 @@ xla_test(
 xla_test(
     name = "call_test",
     srcs = ["call_test.cc"],
+    tags = [
+        "enable_for_xla_interpreter",
+    ],
     deps = [
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",