From: Kay Zhu Date: Tue, 27 Feb 2018 01:55:31 +0000 (-0800) Subject: [XLA::Interpreter] Add support for kConditional to HloEvaluator. Also enable X-Git-Tag: upstream/v1.7.0~87 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ef7c481b0aa563ab8a3bf387e97121382cbaa588;p=platform%2Fupstream%2Ftensorflow.git [XLA::Interpreter] Add support for kConditional to HloEvaluator. Also enable xla/tests/conditional_tests to run on interpreter. PiperOrigin-RevId: 187110438 --- diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index cf8b359..afbfdac 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -2491,6 +2491,34 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { return Status::OK(); } +Status HloEvaluator::HandleConditional(HloInstruction* conditional) { + const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0)); + const auto& true_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(1)); + const auto& false_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(2)); + + auto* true_computation = conditional->true_computation(); + auto* false_computation = conditional->false_computation(); + + auto result = Literal::CreateFromShape(conditional->shape()); + HloEvaluator embedded_evaluator; + if (pred.Get({})) { + result = embedded_evaluator + .Evaluate(*true_computation, + {&true_computation_arg}) + .ConsumeValueOrDie(); + } else { + result = embedded_evaluator + .Evaluate(*false_computation, + {&false_computation_arg}) + .ConsumeValueOrDie(); + } + + evaluated_[conditional] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c65d991..fc82011 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -153,6 +153,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandleCall(HloInstruction* call) override; private: diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index f3ecfc1..19b3dfa 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -478,6 +478,7 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder",