[XLA::Interpreter] Add support for kConditional to HloEvaluator. Also enable
authorKay Zhu <kayzhu@google.com>
Tue, 27 Feb 2018 01:55:31 +0000 (17:55 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
xla/tests/conditional_tests to run on interpreter.

PiperOrigin-RevId: 187110438

tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_evaluator.h
tensorflow/compiler/xla/tests/BUILD

index cf8b359..afbfdac 100644 (file)
@@ -2491,6 +2491,34 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
   return Status::OK();
 }
 
+Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
+  const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0));
+  const auto& true_computation_arg =
+      GetEvaluatedLiteralFor(conditional->operand(1));
+  const auto& false_computation_arg =
+      GetEvaluatedLiteralFor(conditional->operand(2));
+
+  auto* true_computation = conditional->true_computation();
+  auto* false_computation = conditional->false_computation();
+
+  auto result = Literal::CreateFromShape(conditional->shape());
+  HloEvaluator embedded_evaluator;
+  if (pred.Get<bool>({})) {
+    result = embedded_evaluator
+                 .Evaluate<const Literal*>(*true_computation,
+                                           {&true_computation_arg})
+                 .ConsumeValueOrDie();
+  } else {
+    result = embedded_evaluator
+                 .Evaluate<const Literal*>(*false_computation,
+                                           {&false_computation_arg})
+                 .ConsumeValueOrDie();
+  }
+
+  evaluated_[conditional] = std::move(result);
+  return Status::OK();
+}
+
 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
   VLOG(2) << "About to visit HLO: " << hlo->ToString();
   return Status::OK();
index c65d991..fc82011 100644 (file)
@@ -153,6 +153,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 
   Status HandleCopy(HloInstruction* copy) override;
 
+  Status HandleConditional(HloInstruction* conditional) override;
+
   Status HandleCall(HloInstruction* call) override;
 
  private:
index f3ecfc1..19b3dfa 100644 (file)
@@ -478,6 +478,7 @@ xla_test(
 xla_test(
     name = "conditional_test",
     srcs = ["conditional_test.cc"],
+    tags = ["enable_for_xla_interpreter"],
     deps = [
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:computation_builder",