From: Kay Zhu Date: Mon, 26 Feb 2018 23:37:27 +0000 (-0800) Subject: [XLA::Interpreter] Add support for kCall to HloEvaluator. Also enable X-Git-Tag: upstream/v1.7.0~96 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=854a07650f33be545441a08f5db84a0f05a8b88e;p=platform%2Fupstream%2Ftensorflow.git [XLA::Interpreter] Add support for kCall to HloEvaluator. Also enable xla/tests/call_test to run on interpreter. PiperOrigin-RevId: 187092587 --- diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 15ae531..fd06b19 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -2445,6 +2445,26 @@ Status HloEvaluator::HandleCopy(HloInstruction* copy) { return Status::OK(); } +Status HloEvaluator::HandleCall(HloInstruction* call) { + auto* computation = call->to_apply(); + auto operands = call->operands(); + + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[call] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 3b2b697..c65d991 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -153,6 +153,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; + Status HandleCall(HloInstruction* call) override; + private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 97abf21..33fde97 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1143,6 +1143,9 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util",