From 864ddbc9db7611633c7320691353136b4ff557bb Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne
Date: Mon, 5 Mar 2018 11:23:29 -0800
Subject: [PATCH] Extract the EvaluateConstantTensorForEdge method from ShapeRefiner.

This change introduces a new stand-alone function, EvaluateConstantTensor,
pulled from ShapeRefiner. ShapeRefiner now calls this new function and the
old functions are removed. I'm still depending on shape_refiner_test.cc for
test coverage.

This is the first step towards making smart_cond better able to evaluate
constant tensors.

PiperOrigin-RevId: 187894976
---
 tensorflow/core/BUILD                               |   2 +
 tensorflow/core/common_runtime/constant_folding.h   |   2 +
 .../core/common_runtime/eval_const_tensor.cc        | 358 +++++++++++++++++++++
 tensorflow/core/common_runtime/eval_const_tensor.h  |  66 ++++
 tensorflow/core/common_runtime/shape_refiner.cc     | 299 +----------------
 tensorflow/core/common_runtime/shape_refiner.h      |  14 -
 6 files changed, 434 insertions(+), 307 deletions(-)
 create mode 100644 tensorflow/core/common_runtime/eval_const_tensor.cc
 create mode 100644 tensorflow/core/common_runtime/eval_const_tensor.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3a436ff..445cf5b 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2039,6 +2039,7 @@ tf_cuda_library(
 
 CORE_CPU_BASE_HDRS = GRAPH_HDRS + [
     "common_runtime/device.h",
+    "common_runtime/eval_const_tensor.h",
     "common_runtime/graph_runner.h",
     "common_runtime/shape_refiner.h",
     "framework/versions.h",
@@ -2047,6 +2048,7 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [
 tf_cuda_library(
     name = "core_cpu_base",
     srcs = [
+        "common_runtime/eval_const_tensor.cc",
         "common_runtime/shape_refiner.cc",
         "common_runtime/shape_refiner.h",
         "framework/versions.h",
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index b1e1fb8..8459888 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -22,6 +22,8 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/env.h"
 
+// TODO(skyewm): can this be combined with EvaluateConstantTensor?
+
 namespace tensorflow {
 
 // This generator type is used to generate a name for the newly folded node
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc
new file mode 100644
index 0000000..6370bb5
--- /dev/null
+++ b/tensorflow/core/common_runtime/eval_const_tensor.cc
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eval_const_tensor.h"
+
+#include <deque>
+
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+namespace {
+
+// Tries to infer tensor output based on the input shapes of the node. In some
+// cases, the shapes of the inputs are sufficient for inferring the contents of
+// the output tensor. For example, a Shape op with fully defined input shapes
+// can have its output tensor inferred.
+Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
+                                             const ShapeRefiner& refiner,
+                                             Tensor* output, bool* success) {
+  *success = false;
+  const Node* node = edge.src();
+  InferenceContext* c = refiner.GetContext(node);
+  if (c == nullptr) {
+    return errors::FailedPrecondition("Node does not have context.");
+  }
+
+  if (node->type_string() == "Shape") {
+    // If input shapes to the shape op are fully defined,
+    // we can infer the shape op's output tensor.
+    bool fully_defined_inputs = c->FullyDefined(c->input(0));
+    if (fully_defined_inputs) {
+      int input_rank = c->Rank(c->input(0));
+      Tensor t(node->output_type(0), TensorShape({input_rank}));
+      if (node->output_type(0) == DT_INT32) {
+        auto flat = t.flat<int32>();
+        for (int i = 0; i < input_rank; i++) {
+          int64 dimension = c->Value(c->Dim(c->input(0), i));
+          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
+            return errors::InvalidArgument(
+                "Shape has output type int32, but dimension exceeds maximum "
+                "int32 value");
+          }
+          flat(i) = static_cast<int32>(dimension);
+        }
+      } else if (node->output_type(0) == DT_INT64) {
+        auto flat = t.flat<int64>();
+        for (int i = 0; i < input_rank; i++) {
+          flat(i) = c->Value(c->Dim(c->input(0), i));
+        }
+      } else {
+        return errors::FailedPrecondition(
+            "Shape has output type that is not int32 or int64");
+      }
+      *output = t;
+      *success = true;
+    }
+  } else if (node->type_string() == "Rank") {
+    bool rank_known = c->RankKnown(c->input(0));
+    if (rank_known) {
+      int32 input_rank = c->Rank(c->input(0));
+      Tensor t(node->output_type(0), TensorShape({}));
+      t.flat<int32>()(0) = input_rank;
+      *output = t;
+      *success = true;
+    }
+  } else if (node->type_string() == "Size") {
+    bool fully_defined_inputs = c->FullyDefined(c->input(0));
+    if (fully_defined_inputs) {
+      int32 rank = c->Rank(c->input(0));
+      Tensor t(node->output_type(0), TensorShape({}));
+      int64 size = 1;
+      for (int i = 0; i < rank; i++) {
+        size *= c->Value(c->Dim(c->input(0), i));
+      }
+      if (node->output_type(0) == DT_INT32) {
+        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
+          return errors::InvalidArgument(
+              "Size has output type int32, but size exceeds maximum int32 "
+              "value");
+        }
+        t.flat<int32>()(0) = static_cast<int32>(size);
+      } else if (node->output_type(0) == DT_INT64) {
+        t.flat<int64>()(0) = size;
+      } else {
+        return errors::FailedPrecondition(
+            "Size has output type that is not int32 or int64");
+      }
+      *output = t;
+      *success = true;
+    }
+  }
+  return Status::OK();
+}
+
+// Extracts the subgraph ending at 'target_node' that is statically computable
+// and inserts it into 'out_graph'. If the subgraph is statically computable,
+// 'is_constant_graph' will be set to true.
+Status ExtractConstantSubgraph(
+    const Node& target_node, const ShapeRefiner& refiner,
+    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
+    bool* is_constant_graph,
+    std::vector<std::pair<string, Tensor>>* const_inputs) {
+  *is_constant_graph = false;
+  std::unordered_set<string> const_inputs_added;
+
+  if (target_node.op_def().is_stateful()) {
+    return Status::OK();
+  }
+
+  if (target_node.type_string() == "PlaceholderWithDefault") {
+    return Status::OK();
+  }
+
+  // TODO(skyewm): more of the filtering applied to the input nodes below
+  // should be applied to target_node here.
+
+  // Identify the possibly constant subgraph by recursively iterating backwards
+  // through the inputs to 'target_node' until we either 1) find an already
+  // existing input to our subgraph 'const_inputs', 2) discover our graph is
+  // not constant, or 3) hit a root node.
+
+  struct NodeAndRecursed {
+    Node* new_node = nullptr;
+    bool recursed = false;
+  };
+
+  std::map<Node*, NodeAndRecursed> old_to_new_and_recursed;
+  Node* target_node_copy = out_graph->CopyNode(&target_node);
+  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
+  old_to_new_and_recursed[&target_node].recursed = true;
+
+  // Add the target node's inputs to seed the recursion.
+  std::deque<const Edge*> edges_to_visit;
+  for (const Edge* e : target_node.in_edges()) {
+    // TODO(vrv): What do we do about control edges?  Based on our
+    // definition of a constant graph, we should be free to ignore
+    // control edges since the order in which a constant graph is
+    // executed should be the same regardless of when nodes run: we
+    // should only need to recurse down data edges.
+    if (e->IsControlEdge()) continue;
+    edges_to_visit.push_back(e);
+  }
+
+  *is_constant_graph = true;
+
+  // Iterate over the set of edges to visit (backwards).
+  while (!edges_to_visit.empty()) {
+    const Edge* current_edge = edges_to_visit.front();
+    edges_to_visit.pop_front();
+    Node* current_node = current_edge->src();
+
+    // If the node is stateful, assume the graph is not constant.
+    if (current_node->op_def().is_stateful()) {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // During construction or import from GraphConstructor, back edges may not
+    // be filled in. Don't constant fold through merges at all for now.
+    if (IsMerge(current_node)) {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // Don't constant fold enter/exit currently either, as it's easy to end
+    // up with a partial frame.
+    if (IsEnter(current_node) || IsExit(current_node)) {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // Placeholders should never be constant folded because their outputs are
+    // fed by the user. Note that "Placeholder" nodes have no inputs so are
+    // handled below.
+    if (current_node->type_string() == "PlaceholderWithDefault") {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // If there is nothing more to recurse down, see if
+    // the generator node is a constant.
+    if (current_node->num_inputs() == 0) {
+      if (!current_node->IsConstant()) {
+        // Generator node is not a constant, so subgraph is not
+        // constant.
+        *is_constant_graph = false;
+        return Status::OK();
+      }
+    }
+
+    // Either the node is a constant, or the node is a potential
+    // intermediate node on the path from a constant.
+    //
+    // Add a copy of its node and a new edge to the new subgraph.
+
+    // Get or create the version of 'current_node' in the new graph.
+    Node* current_node_copy;
+    // This gets or creates the NodeAndRecursed entry for current_node.
+    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
+    if (node_and_recursed->new_node == nullptr) {
+      // First time processing this node.
+      current_node_copy = out_graph->CopyNode(current_node);
+      // Track the mapping from the original node to the new one.
+      node_and_recursed->new_node = current_node_copy;
+    } else {
+      current_node_copy = node_and_recursed->new_node;
+    }
+
+    // Add the edge to the destination node.
+    {
+      auto it = old_to_new_and_recursed.find(current_edge->dst());
+      if (it == old_to_new_and_recursed.end()) {
+        return errors::Internal(
+            "Could not find mapping from old to new copy of destination node: ",
+            current_edge->dst()->name());
+      }
+      Node* dst_copy = it->second.new_node;
+
+      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
+                         dst_copy, current_edge->dst_input());
+    }
+
+    const string& output_tensor_name =
+        strings::StrCat(current_node->name(), ":", current_edge->src_output());
+
+    // Some tensor values can be inferred. For example, a shape op
+    // with input shapes fully defined can have its output tensor inferred.
+    Tensor tensor_inferred;
+    bool successfully_inferred_tensor = false;
+    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
+        *current_edge, refiner, &tensor_inferred,
+        &successfully_inferred_tensor));
+    if (successfully_inferred_tensor) {
+      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
+      const_inputs_added.insert(output_tensor_name);
+      continue;
+    }
+
+    // If we have a copy of the input tensor materialized already,
+    // then add it to the list of inputs to feed and do not recurse further.
+    if (cached_values != nullptr) {
+      auto it = cached_values->find(output_tensor_name);
+      if (it != cached_values->end() &&
+          const_inputs_added.count(output_tensor_name) == 0) {
+        const_inputs->emplace_back(output_tensor_name, it->second);
+        const_inputs_added.insert(output_tensor_name);
+        continue;
+      }
+    }
+
+    // If this node's inputs have not been processed already, do so now.
+    if (!node_and_recursed->recursed) {
+      node_and_recursed->recursed = true;
+      for (const Edge* e : current_node->in_edges()) {
+        if (e->IsControlEdge()) continue;
+        edges_to_visit.push_back(e);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
+Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
+                              const OpRegistryInterface& ops,
+                              int32 graph_def_version, bool* evaluated,
+                              Tensor* result, GraphRunner* graph_runner,
+                              std::unordered_map<string, Tensor>* cached_values,
+                              int64 max_cached_value_size,
+                              bool disable_constant_propagation) {
+  *evaluated = false;
+  const Node* src = tensor.node;
+
+  // Simple case: the source node is a constant.
+  if (src->IsConstant()) {
+    if (result->FromProto(src->def().attr().at("value").tensor())) {
+      *evaluated = true;
+      return Status::OK();
+    }
+  }
+
+  if (disable_constant_propagation) {
+    return Status::OK();
+  }
+
+  bool is_constant_graph = false;
+  Graph subgraph(&ops);
+  auto versions = subgraph.versions();
+  versions.set_producer(graph_def_version);
+  subgraph.set_versions(versions);
+
+  std::vector<std::pair<string, Tensor>> const_inputs;
+  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
+                                             &subgraph, &is_constant_graph,
+                                             &const_inputs));
+  if (!is_constant_graph) {
+    return Status::OK();
+  }
+  const string output_tensor_name =
+      strings::StrCat(src->name(), ":", tensor.index);
+  std::vector<Tensor> outputs;
+
+  std::unique_ptr<GraphRunner> graph_runner_storage;
+  if (graph_runner == nullptr) {
+    // TODO(skyewm): Convert to std::make_unique when available.
+    graph_runner_storage.reset(new GraphRunner(Env::Default()));
+    graph_runner = graph_runner_storage.get();
+  }
+
+  // NOTE: we should pass in a function library runtime if we want
+  // to support constant-expression evaluation on functions.
+  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
+                               const_inputs, {output_tensor_name}, &outputs);
+
+  // If all kernels in the constant graph are not registered
+  // in the process, GraphRunner::Run may fail, in which case
+  // we cannot propagate constants, so this is best-effort.
+  if (s.ok()) {
+    *result = outputs[0];
+    *evaluated = true;
+
+    // We memoize (small) constants evaluated so far, so
+    // ExtractConstantSubgraph can avoid extracting the full
+    // subgraph. As we build up large graphs, this avoids
+    // repeated computation of the early parts of a constant
+    // graph.
+    if (cached_values != nullptr &&
+        outputs[0].TotalBytes() <= max_cached_value_size) {
+      (*cached_values)[output_tensor_name] = outputs[0];
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.h b/tensorflow/core/common_runtime/eval_const_tensor.h
new file mode 100644
index 0000000..fca5a23
--- /dev/null
+++ b/tensorflow/core/common_runtime/eval_const_tensor.h
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+// TODO(skyewm): can this be combined with ConstantFold?
+
+namespace tensorflow {
+
+class GraphRunner;
+class OpRegistryInterface;
+class ShapeRefiner;
+class Tensor;
+
+// Attempts to evaluate `tensor`. This will only be possible if `tensor` doesn't
+// depend on any graph inputs (this function is safe to call if this isn't the
+// case though).
+//
+// If the evaluation is successful, `evaluated` will be set to true and
+// `tensor`'s value returned in `result`. Otherwise `evaluated` will be set to
+// false. An error status is returned if something is wrong with the graph or
+// input. Note that `evaluated` may be set to false even if Status::OK() is
+// returned.
+//
+// Params:
+//   tensor - the tensor to be evaluated.
+//   refiner - used to fetch the InferenceContexts for nodes in the graph.
+//   ops - the OpRegistryInterface for the graph.
+//   graph_def_version - the producer version of the graph.
+//   evaluated - output param indicating whether evaluation was successful.
+//   result - output param containing the result if evaluated is true.
+//   graph_runner - optional. If not set, a GraphRunner will be created for
+//     evaluating tensor. This can be set to avoid creating a new GraphRunner
+//     for every call.
+//   cached_values - optional.
+//     This can be used to cache evaluated results
+//     across calls, to avoid evaluating the same parts of the graph multiple
+//     times.
+//   max_cached_value_size - optional. If `cached_values` is set, the maximum
+//     result size to cache.
+//   disable_constant_propagation - if true, only Const node values will be
+//     returned.
+Status EvaluateConstantTensor(
+    OutputTensor tensor, const ShapeRefiner& refiner,
+    const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated,
+    Tensor* result, GraphRunner* graph_runner = nullptr,
+    std::unordered_map<string, Tensor>* cached_values = nullptr,
+    int64 max_cached_value_size = 1024,
+    bool disable_constant_propagation = false);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 2acaa31..cef50be 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include
 
+#include "tensorflow/core/common_runtime/eval_const_tensor.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -407,301 +408,13 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                    int dst_idx, bool* evaluated,
                                                    Tensor* result) {
   *evaluated = false;
-
   const Edge* input_edge;
   TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
-
-  // Simple case: the source node is a constant
-  const Node* src = input_edge->src();
-  if (src->IsConstant()) {
-    if (result->FromProto(src->def().attr().at("value").tensor())) {
-      *evaluated = true;
-      return Status::OK();
-    }
-  }
-
-  if (disable_constant_propagation_) {
-    return Status::OK();
-  }
-
-  bool is_constant_graph = false;
-  Graph subgraph(ops_registry_);
-  auto versions = subgraph.versions();
-  versions.set_producer(graph_def_version_);
-  subgraph.set_versions(versions);
-
-  // We identify the possibly constant subgraph to evaluate by
-  // recursively iterating backwards through the inputs to 'node'
-  // until we either 1) find an already existing input to our subgraph
-  // (filled in `const_inputs`), 2) Discover our graph is not constant,
-  // or 3) Hit a root node.
-  std::vector<std::pair<string, Tensor>> const_inputs;
-  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
-      input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
-  if (!is_constant_graph) {
-    return Status::OK();
-  }
-  const string output_tensor_name =
-      strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
-  std::vector<Tensor> outputs;
-
-  // NOTE; we should pass in a function library runtime if we want
-  // to support constant-expression evaluation on functions.
-  Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */,
-                               const_inputs, {output_tensor_name}, &outputs);
-
-  // If all kernels in the constant graph are not registered
-  // in the process, GraphRunner::Run may fail, in which case
-  // we cannot propagate constants, so this is best-effort.
-  if (s.ok()) {
-    *result = outputs[0];
-    *evaluated = true;
-
-    // We memoize (small) constants evaluated so far, so
-    // ExtractConstantSubgraph can avoid extracting the full
-    // subgraph. As we build up large graphs, this avoids
-    // repeated computation of the early parts of a constant
-    // graph.
-    if (outputs[0].TotalBytes() <= kMaxTensorSize) {
-      const_tensor_map_[output_tensor_name] = outputs[0];
-    }
-  }
-  return Status::OK();
-}
-
-Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
-                                                           Tensor* output,
-                                                           bool* success) {
-  *success = false;
-  const Node* node = edge->src();
-  auto it = node_to_context_.find(node);
-  if (it == node_to_context_.end()) {
-    return errors::FailedPrecondition("Node does not have context.");
-  }
-  InferenceContext* c = it->second->get_context();
-
-  if (node->type_string() == "Shape") {
-    // If input shapes to the shape op are fully defined,
-    // we can infer the shape op's output tensor.
-    bool fully_defined_inputs = c->FullyDefined(c->input(0));
-    if (fully_defined_inputs) {
-      int input_rank = c->Rank(c->input(0));
-      Tensor t(node->output_type(0), TensorShape({input_rank}));
-      if (node->output_type(0) == DT_INT32) {
-        auto flat = t.flat<int32>();
-        for (int i = 0; i < input_rank; i++) {
-          int64 dimension = c->Value(c->Dim(c->input(0), i));
-          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
-            return errors::FailedPrecondition(
-                "Shape has output type int32, but dimension exceeds maximum "
-                "int32 value");
-          }
-          flat(i) = static_cast<int32>(dimension);
-        }
-      } else if (node->output_type(0) == DT_INT64) {
-        auto flat = t.flat<int64>();
-        for (int i = 0; i < input_rank; i++) {
-          flat(i) = c->Value(c->Dim(c->input(0), i));
-        }
-      } else {
-        return errors::FailedPrecondition(
-            "Shape has output type that is not int32 or int64");
-      }
-      *output = t;
-      *success = true;
-    }
-  } else if (node->type_string() == "Rank") {
-    bool rank_known = c->RankKnown(c->input(0));
-    if (rank_known) {
-      int32 input_rank = c->Rank(c->input(0));
-      Tensor t(node->output_type(0), TensorShape({}));
-      t.flat<int32>()(0) = input_rank;
-      *output = t;
-      *success = true;
-    }
-  } else if (node->type_string() == "Size") {
-    bool fully_defined_inputs = c->FullyDefined(c->input(0));
-    if (fully_defined_inputs) {
-      int32 rank = c->Rank(c->input(0));
-      Tensor t(node->output_type(0), TensorShape({}));
-      int64 size = 1;
-      for (int i = 0; i < rank; i++) {
-        size *= c->Value(c->Dim(c->input(0), i));
-      }
-      if (node->output_type(0) == DT_INT32) {
-        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
-          return errors::FailedPrecondition(
-              "Size has output type int32, but size exceeds maximum int32 "
-              "value");
-        }
-        t.flat<int32>()(0) = static_cast<int32>(size);
-      } else if (node->output_type(0) == DT_INT64) {
-        t.flat<int64>()(0) = size;
-      } else {
-        return errors::FailedPrecondition(
-            "Size has output type that is not int32 or int64");
-      }
-      *output = t;
-      *success = true;
-    }
-  }
-  return Status::OK();
-}
-
-Status ShapeRefiner::ExtractConstantSubgraph(
-    Node* target_node, Graph* out_graph, bool* is_constant_graph,
-    std::vector<std::pair<string, Tensor>>* const_inputs) {
-  *is_constant_graph = false;
-  std::unordered_set<string> const_inputs_added;
-
-  if (target_node->op_def().is_stateful()) {
-    return Status::OK();
-  }
-
-  if (target_node->type_string() == "PlaceholderWithDefault") {
-    return Status::OK();
-  }
-
-  // TODO(skyewm): more of the filtering applied in input nodes below should be
-  // applied to target_node here
-
-  struct NodeAndRecursed {
-    Node* new_node = nullptr;
-    bool recursed = false;
-  };
-
-  std::map<Node*, NodeAndRecursed> old_to_new_and_recursed;
-  Node* target_node_copy = out_graph->CopyNode(target_node);
-  old_to_new_and_recursed[target_node].new_node = target_node_copy;
-  old_to_new_and_recursed[target_node].recursed = true;
-
-  // Add the target node's inputs to seed the recursion.
-  std::deque<const Edge*> edges_to_visit;
-  for (const Edge* e : target_node->in_edges()) {
-    // TODO(vrv): What do we do about control edges? Based on our
-    // definition of a constant graph, we should be free to ignore
-    // control edges since the order in which a constant graph is
-    // executed should be the same regardless of when nodes run: we
-    // should only need to recurse down data edges.
-    if (e->IsControlEdge()) continue;
-    edges_to_visit.push_back(e);
-  }
-
-  *is_constant_graph = true;
-
-  // Iterate over the set of edges to visit (backwards).
-  while (!edges_to_visit.empty()) {
-    const Edge* current_edge = edges_to_visit.front();
-    edges_to_visit.pop_front();
-    Node* current_node = current_edge->src();
-
-    // If the node is stateful, assume the graph is not constant.
-    if (current_node->op_def().is_stateful()) {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // During construction or import from GraphConstructor, back edges may not
-    // be filled in. Don't constant fold through merges at all for now.
-    if (IsMerge(current_node)) {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // Don't constant fold enter/exit currently either, as it's easy to end
-    // up with a partial frame.
-    if (IsEnter(current_node) || IsExit(current_node)) {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // Placeholders should never be constant folded because their outputs are
-    // fed by the user. Note that "Placeholder" nodes have no inputs so are
-    // handled below.
-    if (current_node->type_string() == "PlaceholderWithDefault") {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // If there is nothing more to recurse down, see if
-    // the generator node is a constant.
-    if (current_node->num_inputs() == 0) {
-      if (!current_node->IsConstant()) {
-        // Generator node is not a constant, so subgraph is not
-        // constant.
-        *is_constant_graph = false;
-        return Status::OK();
-      }
-    }
-
-    // Either the node is a constant, or the node is a potential
-    // intermediate node on the path from a constant.
-    //
-    // Add a copy of its node and a new edge to the new subgraph.
-
-    // Get or create the version of 'current_node' in the new graph.
-    Node* current_node_copy;
-    // This gets or creates the NodeAndRecursed entry for current_node.
-    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
-    if (node_and_recursed->new_node == nullptr) {
-      // First time processing this node.
-      current_node_copy = out_graph->CopyNode(current_node);
-      // Track the mapping from the original node to the new one.
-      node_and_recursed->new_node = current_node_copy;
-    } else {
-      current_node_copy = node_and_recursed->new_node;
-    }
-
-    // Add the edge to the destination node.
-    {
-      auto it = old_to_new_and_recursed.find(current_edge->dst());
-      if (it == old_to_new_and_recursed.end()) {
-        return errors::Internal(
-            "Could not find mapping from old to new copy of destination node: ",
-            current_edge->dst()->name());
-      }
-      Node* dst_copy = it->second.new_node;
-
-      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
-                         dst_copy, current_edge->dst_input());
-    }
-
-    const string& output_tensor_name =
-        strings::StrCat(current_node->name(), ":", current_edge->src_output());
-
-    // Some tensor values can be inferred. For example, a shape op
-    // with input shapes fully defined can have its output tensor inferred.
-    Tensor tensor_inferred;
-    bool successfully_inferred_tensor = false;
-    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
-        current_edge, &tensor_inferred, &successfully_inferred_tensor));
-    if (successfully_inferred_tensor) {
-      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
-      const_inputs_added.insert(output_tensor_name);
-      continue;
-    }
-
-    // If we have a copy of the input tensor materialized already,
-    // then add to the list of inputs to feed and do not recurse further.
-    auto it = const_tensor_map_.find(output_tensor_name);
-    if (it != const_tensor_map_.end() &&
-        const_inputs_added.count(output_tensor_name) == 0) {
-      const_inputs->emplace_back(output_tensor_name, it->second);
-      const_inputs_added.insert(output_tensor_name);
-      continue;
-    }
-
-    // If this node's inputs have not been processed already, do so now.
-    if (!node_and_recursed->recursed) {
-      node_and_recursed->recursed = true;
-      for (const Edge* e : current_node->in_edges()) {
-        if (e->IsControlEdge()) continue;
-        edges_to_visit.push_back(e);
-      }
-    }
-  }
-
-  return Status::OK();
+  OutputTensor tensor(input_edge->src(), input_edge->src_output());
+  return EvaluateConstantTensor(tensor, *this, *ops_registry_,
+                                graph_def_version_, evaluated, result,
+                                &graph_runner_, &const_tensor_map_,
+                                kMaxTensorSize, disable_constant_propagation_);
 }
 
 Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index 75eb5bf..d49c437 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -215,20 +215,6 @@ class ShapeRefiner {
                                 bool keep_nested_shapes,
                                 ExtendedInferenceContext* outer_context);
 
-  // Tries to infer tensor output based on the input shapes of the node. In some
-  // cases, the shapes of the inputs are sufficient for inferring the contents
-  // of the output tensor. For example, a Shape op with fully defined input
-  // shapes can have its output tensor inferred.
-  Status TryToInferTensorOutputFromInputShapes(const Edge* edge, Tensor* output,
-                                               bool* success);
-
-  // Extracts the subgraph ending at 'node' that is statically
-  // computable and inserts into 'out_graph'. If statically computable,
-  // 'is_constant_graph' will be true.
-  Status ExtractConstantSubgraph(
-      Node* node, Graph* out_graph, bool* is_constant_graph,
-      std::vector<std::pair<string, Tensor>>* const_inputs) TF_MUST_USE_RESULT;
-
   Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
                                        bool* evaluated, Tensor* result);
-- 
2.7.4
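
Usage sketch (not part of the patch): the snippet below illustrates how a caller outside ShapeRefiner, such as a future smart_cond helper, might invoke the new entry point. The helper name MaybeGetConstantValue and the surrounding setup are hypothetical; only the EvaluateConstantTensor signature and its defaulted arguments come from eval_const_tensor.h above.

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Hypothetical helper: tries to resolve `tensor` (an output of some node in
// `graph`) to a concrete value. `refiner` must already have run shape
// inference over the graph so that the Shape/Rank/Size shortcuts can fire.
Status MaybeGetConstantValue(const Graph& graph, const ShapeRefiner& refiner,
                             OutputTensor tensor, bool* has_value,
                             Tensor* value) {
  // Defaults are used for the trailing parameters: a fresh GraphRunner is
  // created for this call and no result caching is done.
  return EvaluateConstantTensor(tensor, refiner, *graph.op_registry(),
                                graph.versions().producer(), has_value, value);
}

}  // namespace tensorflow

ShapeRefiner::EvaluateConstantTensorForEdge in the patch above is the in-tree example of the same call, with the GraphRunner reuse and caching arguments supplied explicitly.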