Extract the EvaluateConstantTensorForEdge method from ShapeRefiner.
authorSkye Wanderman-Milne <skyewm@google.com>
Mon, 5 Mar 2018 19:23:29 +0000 (11:23 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 19:28:50 +0000 (11:28 -0800)
This change introduces a new stand-alone function, EvaluateConstantTensor,
pulled from ShapeRefiner. ShapeRefiner now calls this new function and the
old functions are removed.

I'm still depending on shape_refiner_test.cc for test coverage.

This is the first step towards making smart_cond better able to evaluate
constant tensors.

PiperOrigin-RevId: 187894976

tensorflow/core/BUILD
tensorflow/core/common_runtime/constant_folding.h
tensorflow/core/common_runtime/eval_const_tensor.cc [new file with mode: 0644]
tensorflow/core/common_runtime/eval_const_tensor.h [new file with mode: 0644]
tensorflow/core/common_runtime/shape_refiner.cc
tensorflow/core/common_runtime/shape_refiner.h

index 3a436ff..445cf5b 100644 (file)
@@ -2039,6 +2039,7 @@ tf_cuda_library(
 
 CORE_CPU_BASE_HDRS = GRAPH_HDRS + [
     "common_runtime/device.h",
+    "common_runtime/eval_const_tensor.h",
     "common_runtime/graph_runner.h",
     "common_runtime/shape_refiner.h",
     "framework/versions.h",
@@ -2047,6 +2048,7 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [
 tf_cuda_library(
     name = "core_cpu_base",
     srcs = [
+        "common_runtime/eval_const_tensor.cc",
         "common_runtime/shape_refiner.cc",
         "common_runtime/shape_refiner.h",
         "framework/versions.h",
index b1e1fb8..8459888 100644 (file)
@@ -22,6 +22,8 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/env.h"
 
+// TODO(skyewm): can this be combined with EvaluateConstantTensor?
+
 namespace tensorflow {
 
 // This generator type is used to generate a name for the newly folded node
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc
new file mode 100644 (file)
index 0000000..6370bb5
--- /dev/null
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eval_const_tensor.h"
+
+#include <deque>
+
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+namespace {
+
+// Tries to infer tensor output based on the input shapes of the node. In some
+// cases, the shapes of the inputs are sufficient for inferring the contents of
+// the output tensor. For example, a Shape op with fully defined input shapes
+// can have its output tensor inferred.
+Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
+                                             const ShapeRefiner& refiner,
+                                             Tensor* output, bool* success) {
+  *success = false;
+  const Node* node = edge.src();
+  InferenceContext* c = refiner.GetContext(node);
+  if (c == nullptr) {
+    return errors::FailedPrecondition("Node does not have context.");
+  }
+
+  if (node->type_string() == "Shape") {
+    // If input shapes to the shape op are fully defined,
+    // we can infer the shape op's output tensor.
+    bool fully_defined_inputs = c->FullyDefined(c->input(0));
+    if (fully_defined_inputs) {
+      int input_rank = c->Rank(c->input(0));
+      Tensor t(node->output_type(0), TensorShape({input_rank}));
+      if (node->output_type(0) == DT_INT32) {
+        auto flat = t.flat<int>();
+        for (int i = 0; i < input_rank; i++) {
+          int64 dimension = c->Value(c->Dim(c->input(0), i));
+          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
+            return errors::InvalidArgument(
+                "Shape has output type int32, but dimension exceeds maximum "
+                "int32 value");
+          }
+          flat(i) = static_cast<int32>(dimension);
+        }
+      } else if (node->output_type(0) == DT_INT64) {
+        auto flat = t.flat<int64>();
+        for (int i = 0; i < input_rank; i++) {
+          flat(i) = c->Value(c->Dim(c->input(0), i));
+        }
+      } else {
+        return errors::FailedPrecondition(
+            "Shape has output type that is not int32 or int64");
+      }
+      *output = t;
+      *success = true;
+    }
+  } else if (node->type_string() == "Rank") {
+    bool rank_known = c->RankKnown(c->input(0));
+    if (rank_known) {
+      int32 input_rank = c->Rank(c->input(0));
+      Tensor t(node->output_type(0), TensorShape({}));
+      t.flat<int32>()(0) = input_rank;
+      *output = t;
+      *success = true;
+    }
+  } else if (node->type_string() == "Size") {
+    bool fully_defined_inputs = c->FullyDefined(c->input(0));
+    if (fully_defined_inputs) {
+      int32 rank = c->Rank(c->input(0));
+      Tensor t(node->output_type(0), TensorShape({}));
+      int64 size = 1;
+      for (int i = 0; i < rank; i++) {
+        size *= c->Value(c->Dim(c->input(0), i));
+      }
+      if (node->output_type(0) == DT_INT32) {
+        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
+          return errors::InvalidArgument(
+              "Size has output type int32, but size exceeds maximum int32 "
+              "value");
+        }
+        t.flat<int32>()(0) = static_cast<int32>(size);
+      } else if (node->output_type(0) == DT_INT64) {
+        t.flat<int64>()(0) = size;
+      } else {
+        return errors::FailedPrecondition(
+            "Size has output type that is not int32 or int64");
+      }
+      *output = t;
+      *success = true;
+    }
+  }
+  return Status::OK();
+}
+
+// Extracts the subgraph ending at 'target_node' that is statically computable
+// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
+// will be set to true.
+Status ExtractConstantSubgraph(
+    const Node& target_node, const ShapeRefiner& refiner,
+    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
+    bool* is_constant_graph,
+    std::vector<std::pair<string, Tensor>>* const_inputs) {
+  *is_constant_graph = false;
+  std::unordered_set<string> const_inputs_added;
+
+  if (target_node.op_def().is_stateful()) {
+    return Status::OK();
+  }
+
+  if (target_node.type_string() == "PlaceholderWithDefault") {
+    return Status::OK();
+  }
+
+  // TODO(skyewm): more of the filtering applied in input nodes below should be
+  // applied to target_node here
+
+  // Identify the possibly constant subgraph by recursively iterating backwards
+  // through the inputs to 'target_node' until we either 1) find an already
+  // existing input to our subgraph 'const_inputs', 2) Discover our graph is not
+  // constant, or 3) Hit a root node.
+
+  struct NodeAndRecursed {
+    Node* new_node = nullptr;
+    bool recursed = false;
+  };
+
+  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
+  Node* target_node_copy = out_graph->CopyNode(&target_node);
+  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
+  old_to_new_and_recursed[&target_node].recursed = true;
+
+  // Add the target node's inputs to seed the recursion.
+  std::deque<const Edge*> edges_to_visit;
+  for (const Edge* e : target_node.in_edges()) {
+    // TODO(vrv): What do we do about control edges?  Based on our
+    // definition of a constant graph, we should be free to ignore
+    // control edges since the order in which a constant graph is
+    // executed should be the same regardless of when nodes run: we
+    // should only need to recurse down data edges.
+    if (e->IsControlEdge()) continue;
+    edges_to_visit.push_back(e);
+  }
+
+  *is_constant_graph = true;
+
+  // Iterate over the set of edges to visit (backwards).
+  while (!edges_to_visit.empty()) {
+    const Edge* current_edge = edges_to_visit.front();
+    edges_to_visit.pop_front();
+    Node* current_node = current_edge->src();
+
+    // If the node is stateful, assume the graph is not constant.
+    if (current_node->op_def().is_stateful()) {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // During construction or import from GraphConstructor, back edges may not
+    // be filled in.  Don't constant fold through merges at all for now.
+    if (IsMerge(current_node)) {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // Don't constant fold enter/exit currently either, as it's easy to end
+    // up with a partial frame.
+    if (IsEnter(current_node) || IsExit(current_node)) {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // Placeholders should never be constant folded because their outputs are
+    // fed by the user. Note that "Placeholder" nodes have no inputs so are
+    // handled below.
+    if (current_node->type_string() == "PlaceholderWithDefault") {
+      *is_constant_graph = false;
+      return Status::OK();
+    }
+
+    // If there is nothing more to recurse down, see if
+    // the generator node is a constant.
+    if (current_node->num_inputs() == 0) {
+      if (!current_node->IsConstant()) {
+        // Generator node is not a constant, so subgraph is not
+        // constant.
+        *is_constant_graph = false;
+        return Status::OK();
+      }
+    }
+
+    // Either the node is a constant, or the node is a potential
+    // intermediate node on the path from a constant.
+    //
+    // Add a copy of its node and a new edge to the new subgraph.
+
+    // Get or create the version of 'current_node' in the new graph.
+    Node* current_node_copy;
+    // This gets or creates the NodeAndRecursed entry for current_node.
+    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
+    if (node_and_recursed->new_node == nullptr) {
+      // First time processing this node.
+      current_node_copy = out_graph->CopyNode(current_node);
+      // Track the mapping from the original node to the new one.
+      node_and_recursed->new_node = current_node_copy;
+    } else {
+      current_node_copy = node_and_recursed->new_node;
+    }
+
+    // Add the edge to the destination node.
+    {
+      auto it = old_to_new_and_recursed.find(current_edge->dst());
+      if (it == old_to_new_and_recursed.end()) {
+        return errors::Internal(
+            "Could not find mapping from old to new copy of destination node: ",
+            current_edge->dst()->name());
+      }
+      Node* dst_copy = it->second.new_node;
+
+      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
+                         dst_copy, current_edge->dst_input());
+    }
+
+    const string& output_tensor_name =
+        strings::StrCat(current_node->name(), ":", current_edge->src_output());
+
+    // Some tensor values can be inferred. For example, a shape op
+    // with input shapes fully defined can have its output tensor inferred.
+    Tensor tensor_inferred;
+    bool successfully_inferred_tensor = false;
+    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
+        *current_edge, refiner, &tensor_inferred,
+        &successfully_inferred_tensor));
+    if (successfully_inferred_tensor) {
+      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
+      const_inputs_added.insert(output_tensor_name);
+      continue;
+    }
+
+    // If we have a copy of the input tensor materialized already,
+    // then add to the list of inputs to feed and do not recurse further.
+    if (cached_values != nullptr) {
+      auto it = cached_values->find(output_tensor_name);
+      if (it != cached_values->end() &&
+          const_inputs_added.count(output_tensor_name) == 0) {
+        const_inputs->emplace_back(output_tensor_name, it->second);
+        const_inputs_added.insert(output_tensor_name);
+        continue;
+      }
+    }
+
+    // If this node's inputs have not been processed already, do so now.
+    if (!node_and_recursed->recursed) {
+      node_and_recursed->recursed = true;
+      for (const Edge* e : current_node->in_edges()) {
+        if (e->IsControlEdge()) continue;
+        edges_to_visit.push_back(e);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
+Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
+                              const OpRegistryInterface& ops,
+                              int32 graph_def_version, bool* evaluated,
+                              Tensor* result, GraphRunner* graph_runner,
+                              std::unordered_map<string, Tensor>* cached_values,
+                              int64 max_cached_value_size,
+                              bool disable_constant_propagation) {
+  *evaluated = false;
+  const Node* src = tensor.node;
+
+  // Simple case: the source node is a constant
+  if (src->IsConstant()) {
+    if (result->FromProto(src->def().attr().at("value").tensor())) {
+      *evaluated = true;
+      return Status::OK();
+    }
+  }
+
+  if (disable_constant_propagation) {
+    return Status::OK();
+  }
+
+  bool is_constant_graph = false;
+  Graph subgraph(&ops);
+  auto versions = subgraph.versions();
+  versions.set_producer(graph_def_version);
+  subgraph.set_versions(versions);
+
+  std::vector<std::pair<string, Tensor>> const_inputs;
+  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
+                                             &subgraph, &is_constant_graph,
+                                             &const_inputs));
+  if (!is_constant_graph) {
+    return Status::OK();
+  }
+  const string output_tensor_name =
+      strings::StrCat(src->name(), ":", tensor.index);
+  std::vector<Tensor> outputs;
+
+  std::unique_ptr<GraphRunner> graph_runner_storage;
+  if (graph_runner == nullptr) {
+    // TODO(skyewm): Convert to std::make_unique when available.
+    graph_runner_storage.reset(new GraphRunner(Env::Default()));
+    graph_runner = graph_runner_storage.get();
+  }
+
+  // NOTE; we should pass in a function library runtime if we want
+  // to support constant-expression evaluation on functions.
+  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
+                               const_inputs, {output_tensor_name}, &outputs);
+
+  // If all kernels in the constant graph are not registered
+  // in the process, GraphRunner::Run may fail, in which case
+  // we cannot propagate constants, so this is best-effort.
+  if (s.ok()) {
+    *result = outputs[0];
+    *evaluated = true;
+
+    // We memoize (small) constants evaluated so far, so
+    // ExtractConstantSubgraph can avoid extracting the full
+    // subgraph.  As we build up large graphs, this avoids
+    // repeated computation of the early parts of a constant
+    // graph.
+    if (cached_values != nullptr &&
+        outputs[0].TotalBytes() <= max_cached_value_size) {
+      (*cached_values)[output_tensor_name] = outputs[0];
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.h b/tensorflow/core/common_runtime/eval_const_tensor.h
new file mode 100644 (file)
index 0000000..fca5a23
--- /dev/null
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+// TODO(skyewm): can this be combined with ConstantFold?
+
+namespace tensorflow {
+
+class GraphRunner;
+class OpRegistryInterface;
+class ShapeRefiner;
+class Tensor;
+
+// Attempts to evaluate `tensor`. This will only be possible if `tensor` doesn't
+// depend on any graph inputs (this function is safe to call if this isn't the
+// case though).
+//
+// If the evaluation is successful, `evaluated` will be set to true and
+// `tensor`s value returned in `result`. Otherwise `evaluated` will be set to
+// false. An error status is returned if something is wrong with the graph or
+// input. Note that `evaluated` may set to false if Status::OK() is returned.
+//
+// Params:
+//   tensor - the tensor to be evaluated.
+//   refiner - used to fetch the InferenceContexts for nodes in the graph.
+//   ops - the OpRegistryInterface for the graph.
+//   graph_def_version - the producer version of the graph.
+//   evaluated - output param indicating whether evaluation was successful.
+//   result - output param containing the result if evaluated is true.
+//   graph_runner - optional. If not set, a GraphRunner will be created for
+//     evaluating tensor. This can be set to avoid creating a new GraphRunner
+//     for every call.
+//   cached_values - optional. This can be used to cache evaluated results
+//     across calls, to avoid evaluating the same parts of the graph multiple
+//     times.
+//   max_cached_value_size - optional. If `cached_values` is set, the maximum
+//     result size to cache.
+//   disable_constant_propagation - if true, only Const node values will be
+//     returned.
+Status EvaluateConstantTensor(
+    OutputTensor tensor, const ShapeRefiner& refiner,
+    const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated,
+    Tensor* result, GraphRunner* graph_runner = nullptr,
+    std::unordered_map<string, Tensor>* cached_values = nullptr,
+    int64 max_cached_value_size = 1024,
+    bool disable_constant_propagation = false);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
index 2acaa31..cef50be 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>
 
+#include "tensorflow/core/common_runtime/eval_const_tensor.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -407,301 +408,13 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                    int dst_idx, bool* evaluated,
                                                    Tensor* result) {
   *evaluated = false;
-
   const Edge* input_edge;
   TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
-
-  // Simple case: the source node is a constant
-  const Node* src = input_edge->src();
-  if (src->IsConstant()) {
-    if (result->FromProto(src->def().attr().at("value").tensor())) {
-      *evaluated = true;
-      return Status::OK();
-    }
-  }
-
-  if (disable_constant_propagation_) {
-    return Status::OK();
-  }
-
-  bool is_constant_graph = false;
-  Graph subgraph(ops_registry_);
-  auto versions = subgraph.versions();
-  versions.set_producer(graph_def_version_);
-  subgraph.set_versions(versions);
-
-  // We identify the possibly constant subgraph to evaluate by
-  // recursively iterating backwards through the inputs to 'node'
-  // until we either 1) find an already existing input to our subgraph
-  // (filled in `const_inputs`), 2) Discover our graph is not constant,
-  // or 3) Hit a root node.
-  std::vector<std::pair<string, Tensor>> const_inputs;
-  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
-      input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
-  if (!is_constant_graph) {
-    return Status::OK();
-  }
-  const string output_tensor_name =
-      strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
-  std::vector<Tensor> outputs;
-
-  // NOTE; we should pass in a function library runtime if we want
-  // to support constant-expression evaluation on functions.
-  Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */,
-                               const_inputs, {output_tensor_name}, &outputs);
-
-  // If all kernels in the constant graph are not registered
-  // in the process, GraphRunner::Run may fail, in which case
-  // we cannot propagate constants, so this is best-effort.
-  if (s.ok()) {
-    *result = outputs[0];
-    *evaluated = true;
-
-    // We memoize (small) constants evaluated so far, so
-    // ExtractConstantSubgraph can avoid extracting the full
-    // subgraph.  As we build up large graphs, this avoids
-    // repeated computation of the early parts of a constant
-    // graph.
-    if (outputs[0].TotalBytes() <= kMaxTensorSize) {
-      const_tensor_map_[output_tensor_name] = outputs[0];
-    }
-  }
-  return Status::OK();
-}
-
-Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
-                                                           Tensor* output,
-                                                           bool* success) {
-  *success = false;
-  const Node* node = edge->src();
-  auto it = node_to_context_.find(node);
-  if (it == node_to_context_.end()) {
-    return errors::FailedPrecondition("Node does not have context.");
-  }
-  InferenceContext* c = it->second->get_context();
-
-  if (node->type_string() == "Shape") {
-    // If input shapes to the shape op are fully defined,
-    // we can infer the shape op's output tensor.
-    bool fully_defined_inputs = c->FullyDefined(c->input(0));
-    if (fully_defined_inputs) {
-      int input_rank = c->Rank(c->input(0));
-      Tensor t(node->output_type(0), TensorShape({input_rank}));
-      if (node->output_type(0) == DT_INT32) {
-        auto flat = t.flat<int>();
-        for (int i = 0; i < input_rank; i++) {
-          int64 dimension = c->Value(c->Dim(c->input(0), i));
-          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
-            return errors::FailedPrecondition(
-                "Shape has output type int32, but dimension exceeds maximum "
-                "int32 value");
-          }
-          flat(i) = static_cast<int32>(dimension);
-        }
-      } else if (node->output_type(0) == DT_INT64) {
-        auto flat = t.flat<int64>();
-        for (int i = 0; i < input_rank; i++) {
-          flat(i) = c->Value(c->Dim(c->input(0), i));
-        }
-      } else {
-        return errors::FailedPrecondition(
-            "Shape has output type that is not int32 or int64");
-      }
-      *output = t;
-      *success = true;
-    }
-  } else if (node->type_string() == "Rank") {
-    bool rank_known = c->RankKnown(c->input(0));
-    if (rank_known) {
-      int32 input_rank = c->Rank(c->input(0));
-      Tensor t(node->output_type(0), TensorShape({}));
-      t.flat<int32>()(0) = input_rank;
-      *output = t;
-      *success = true;
-    }
-  } else if (node->type_string() == "Size") {
-    bool fully_defined_inputs = c->FullyDefined(c->input(0));
-    if (fully_defined_inputs) {
-      int32 rank = c->Rank(c->input(0));
-      Tensor t(node->output_type(0), TensorShape({}));
-      int64 size = 1;
-      for (int i = 0; i < rank; i++) {
-        size *= c->Value(c->Dim(c->input(0), i));
-      }
-      if (node->output_type(0) == DT_INT32) {
-        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
-          return errors::FailedPrecondition(
-              "Size has output type int32, but size exceeds maximum int32 "
-              "value");
-        }
-        t.flat<int32>()(0) = static_cast<int32>(size);
-      } else if (node->output_type(0) == DT_INT64) {
-        t.flat<int64>()(0) = size;
-      } else {
-        return errors::FailedPrecondition(
-            "Size has output type that is not int32 or int64");
-      }
-      *output = t;
-      *success = true;
-    }
-  }
-  return Status::OK();
-}
-
-Status ShapeRefiner::ExtractConstantSubgraph(
-    Node* target_node, Graph* out_graph, bool* is_constant_graph,
-    std::vector<std::pair<string, Tensor>>* const_inputs) {
-  *is_constant_graph = false;
-  std::unordered_set<string> const_inputs_added;
-
-  if (target_node->op_def().is_stateful()) {
-    return Status::OK();
-  }
-
-  if (target_node->type_string() == "PlaceholderWithDefault") {
-    return Status::OK();
-  }
-
-  // TODO(skyewm): more of the filtering applied in input nodes below should be
-  // applied to target_node here
-
-  struct NodeAndRecursed {
-    Node* new_node = nullptr;
-    bool recursed = false;
-  };
-
-  std::map<Node*, NodeAndRecursed> old_to_new_and_recursed;
-  Node* target_node_copy = out_graph->CopyNode(target_node);
-  old_to_new_and_recursed[target_node].new_node = target_node_copy;
-  old_to_new_and_recursed[target_node].recursed = true;
-
-  // Add the target node's inputs to seed the recursion.
-  std::deque<const Edge*> edges_to_visit;
-  for (const Edge* e : target_node->in_edges()) {
-    // TODO(vrv): What do we do about control edges?  Based on our
-    // definition of a constant graph, we should be free to ignore
-    // control edges since the order in which a constant graph is
-    // executed should be the same regardless of when nodes run: we
-    // should only need to recurse down data edges.
-    if (e->IsControlEdge()) continue;
-    edges_to_visit.push_back(e);
-  }
-
-  *is_constant_graph = true;
-
-  // Iterate over the set of edges to visit (backwards).
-  while (!edges_to_visit.empty()) {
-    const Edge* current_edge = edges_to_visit.front();
-    edges_to_visit.pop_front();
-    Node* current_node = current_edge->src();
-
-    // If the node is stateful, assume the graph is not constant.
-    if (current_node->op_def().is_stateful()) {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // During construction or import from GraphConstructor, back edges may not
-    // be filled in.  Don't constant fold through merges at all for now.
-    if (IsMerge(current_node)) {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // Don't constant fold enter/exit currently either, as it's easy to end
-    // up with a partial frame.
-    if (IsEnter(current_node) || IsExit(current_node)) {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // Placeholders should never be constant folded because their outputs are
-    // fed by the user. Note that "Placeholder" nodes have no inputs so are
-    // handled below.
-    if (current_node->type_string() == "PlaceholderWithDefault") {
-      *is_constant_graph = false;
-      return Status::OK();
-    }
-
-    // If there is nothing more to recurse down, see if
-    // the generator node is a constant.
-    if (current_node->num_inputs() == 0) {
-      if (!current_node->IsConstant()) {
-        // Generator node is not a constant, so subgraph is not
-        // constant.
-        *is_constant_graph = false;
-        return Status::OK();
-      }
-    }
-
-    // Either the node is a constant, or the node is a potential
-    // intermediate node on the path from a constant.
-    //
-    // Add a copy of its node and a new edge to the new subgraph.
-
-    // Get or create the version of 'current_node' in the new graph.
-    Node* current_node_copy;
-    // This gets or creates the NodeAndRecursed entry for current_node.
-    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
-    if (node_and_recursed->new_node == nullptr) {
-      // First time processing this node.
-      current_node_copy = out_graph->CopyNode(current_node);
-      // Track the mapping from the original node to the new one.
-      node_and_recursed->new_node = current_node_copy;
-    } else {
-      current_node_copy = node_and_recursed->new_node;
-    }
-
-    // Add the edge to the destination node.
-    {
-      auto it = old_to_new_and_recursed.find(current_edge->dst());
-      if (it == old_to_new_and_recursed.end()) {
-        return errors::Internal(
-            "Could not find mapping from old to new copy of destination node: ",
-            current_edge->dst()->name());
-      }
-      Node* dst_copy = it->second.new_node;
-
-      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
-                         dst_copy, current_edge->dst_input());
-    }
-
-    const string& output_tensor_name =
-        strings::StrCat(current_node->name(), ":", current_edge->src_output());
-
-    // Some tensor values can be inferred. For example, a shape op
-    // with input shapes fully defined can have its output tensor inferred.
-    Tensor tensor_inferred;
-    bool successfully_inferred_tensor = false;
-    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
-        current_edge, &tensor_inferred, &successfully_inferred_tensor));
-    if (successfully_inferred_tensor) {
-      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
-      const_inputs_added.insert(output_tensor_name);
-      continue;
-    }
-
-    // If we have a copy of the input tensor materialized already,
-    // then add to the list of inputs to feed and do not recurse further.
-    auto it = const_tensor_map_.find(output_tensor_name);
-    if (it != const_tensor_map_.end() &&
-        const_inputs_added.count(output_tensor_name) == 0) {
-      const_inputs->emplace_back(output_tensor_name, it->second);
-      const_inputs_added.insert(output_tensor_name);
-      continue;
-    }
-
-    // If this node's inputs have not been processed already, do so now.
-    if (!node_and_recursed->recursed) {
-      node_and_recursed->recursed = true;
-      for (const Edge* e : current_node->in_edges()) {
-        if (e->IsControlEdge()) continue;
-        edges_to_visit.push_back(e);
-      }
-    }
-  }
-
-  return Status::OK();
+  OutputTensor tensor(input_edge->src(), input_edge->src_output());
+  return EvaluateConstantTensor(tensor, *this, *ops_registry_,
+                                graph_def_version_, evaluated, result,
+                                &graph_runner_, &const_tensor_map_,
+                                kMaxTensorSize, disable_constant_propagation_);
 }
 
 Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
index 75eb5bf..d49c437 100644 (file)
@@ -215,20 +215,6 @@ class ShapeRefiner {
                                 bool keep_nested_shapes,
                                 ExtendedInferenceContext* outer_context);
 
-  // Tries to infer tensor output based on the input shapes of the node. In some
-  // cases, the shapes of the inputs are sufficient for inferring the contents
-  // of the output tensor. For example, a Shape op with fully defined input
-  // shapes can have its output tensor inferred.
-  Status TryToInferTensorOutputFromInputShapes(const Edge* edge, Tensor* output,
-                                               bool* success);
-
-  // Extracts the subgraph ending at 'node' that is statically
-  // computable and inserts into 'out_graph'. If statically computable,
-  // 'is_constant_graph' will be true.
-  Status ExtractConstantSubgraph(
-      Node* node, Graph* out_graph, bool* is_constant_graph,
-      std::vector<std::pair<string, Tensor>>* const_inputs) TF_MUST_USE_RESULT;
-
   Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
                                        bool* evaluated, Tensor* result);