Enable constant propagation across Switch(x,x) by rewriting the two outputs as Const...
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Feb 2018 03:47:03 +0000 (19:47 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 03:51:04 +0000 (19:51 -0800)
Switch(x, x) always feeds false to its false output and true to its true output. By rewriting the graph a bit, we can propagate these constants down the two output branches and use control dependencies to trigger the selected branch at runtime. For example,
         +------+
     x-->|Switch|-->a
     x-->|      |-->b
         +------+

Is rewritten as

         +------+
     x-->|Switch|-->Identity--^>Const(false)-->a
     x-->|      |-->Identity--^>Const(true)-->b
         +------+

(In practice there may be multiple consumers of each output branch.)
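
To make the wiring concrete, here is a minimal illustrative sketch (not part of this change) of the rewritten subgraph built directly with the GraphDef proto API. The node names and the omitted dtype/value attrs on the Const nodes are placeholders; the actual optimizer derives them from the switch node, as in the constant_folding.cc change below.

  // Sketch only: the shape of the rewritten subgraph, expressed with the
  // GraphDef proto API. Node names are illustrative; the Const nodes' dtype
  // and value attrs are omitted for brevity.
  #include "tensorflow/core/framework/graph.pb.h"
  #include "tensorflow/core/framework/node_def.pb.h"

  tensorflow::GraphDef BuildRewrittenSwitch() {
    tensorflow::GraphDef graph;

    // Switch whose predicate and data input are the same tensor "x".
    tensorflow::NodeDef* sw = graph.add_node();
    sw->set_name("switch");
    sw->set_op("Switch");
    sw->add_input("x");
    sw->add_input("x");

    // One Identity per output port; they exist only to anchor control edges.
    tensorflow::NodeDef* id_false = graph.add_node();
    id_false->set_name("ctrl/switch_0");
    id_false->set_op("Identity");
    id_false->add_input("switch");     // port 0: the false branch

    tensorflow::NodeDef* id_true = graph.add_node();
    id_true->set_name("ctrl/switch_1");
    id_true->set_op("Identity");
    id_true->add_input("switch:1");    // port 1: the true branch

    // Scalar bool constants that replace the switch outputs; the control
    // edges make each constant fire only when its branch is taken.
    tensorflow::NodeDef* const_false = graph.add_node();
    const_false->set_name("const_false");
    const_false->set_op("Const");
    const_false->add_input("^ctrl/switch_0");

    tensorflow::NodeDef* const_true = graph.add_node();
    const_true->set_name("const_true");
    const_true->set_op("Const");
    const_true->add_input("^ctrl/switch_1");

    // Consumers that previously read "switch" now read "const_false", and
    // consumers of "switch:1" now read "const_true".
    return graph;
  }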

PiperOrigin-RevId: 186714991

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc
tensorflow/core/grappler/utils.cc
tensorflow/core/grappler/utils.h

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index fbb3e5aaee09e58e1296280fa49d5e35326c3696..709a434e40e887502cac1317870eb0db8e0c2910 100644 (file)
@@ -45,45 +45,6 @@ namespace tensorflow {
 namespace grappler {
 namespace {
 
-template <typename T>
-bool SafeSetTensorValue(double value, Tensor* tensor) {
-  using RealType = typename Eigen::NumTraits<T>::Real;
-  if (value > std::numeric_limits<RealType>::max() ||
-      value < std::numeric_limits<RealType>::min()) {
-    return false;
-  }
-  tensor->flat<T>()(0) = static_cast<T>(value);
-  return true;
-}
-
-#define HANDLE_CASE(DTYPE)                                          \
-  case DTYPE:                                                       \
-    if (!SafeSetTensorValue<EnumToDataType<DTYPE>::Type>(           \
-            static_cast<double>(value), tensor)) {                  \
-      return errors::InvalidArgument("Cannot store value ", value,  \
-                                     " in tensor of type " #DTYPE); \
-    }                                                               \
-    break
-
-Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
-  switch (dtype) {
-    //    HANDLE_CASE(DT_HALF);
-    HANDLE_CASE(DT_FLOAT);
-    HANDLE_CASE(DT_DOUBLE);
-    HANDLE_CASE(DT_UINT8);
-    HANDLE_CASE(DT_INT8);
-    HANDLE_CASE(DT_UINT16);
-    HANDLE_CASE(DT_INT16);
-    HANDLE_CASE(DT_INT32);
-    HANDLE_CASE(DT_INT64);
-    HANDLE_CASE(DT_COMPLEX64);
-    HANDLE_CASE(DT_COMPLEX128);
-    default:
-      return errors::InvalidArgument("Unexpected type ", DataTypeString(dtype));
-  }
-  return Status::OK();
-}
-
 template <typename T>
 bool AreInversePermutations(const std::vector<T>& a, const std::vector<T>& b) {
   if (a.size() != b.size()) {
tensorflow/core/grappler/optimizers/constant_folding.cc
index 064cb8b5ae0d4fa0ed83386592d351055826fda0..182e03f04e205f4426db716b1ac29fe18c8acc7e 100644 (file)
@@ -811,44 +811,51 @@ Status ConstantFolding::CreateNodeDef(const string& name,
   // Use the packed representation whenever possible to avoid generating large
   // graphdefs. Moreover, avoid repeating the last values if they're equal.
   if (tensor->NumElements() > 4) {
-#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME)                \
-  const TYPE* val_ptr = tensor->flat<TYPE>().data();                \
-  TYPE last = *val_ptr;                                             \
-  int64 last_index = 0;                                             \
-  for (int64 i = 0; i < tensor->NumElements(); ++i) {               \
-    TYPE cur = *val_ptr++;                                          \
-    if (cur != last) {                                              \
-      last = cur;                                                   \
-      last_index = i;                                               \
-    }                                                               \
-  }                                                                 \
-  if (last_index < kint32max) {                                     \
-    optimized = true;                                               \
-    encoded_size = (last_index + 1) * sizeof(NAME);                 \
-    t->mutable_##NAME##_val()->Reserve(last_index + 1);             \
-    t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \
-    val_ptr = tensor->flat<TYPE>().data();                          \
-    for (int64 i = 0; i <= last_index; ++i) {                       \
-      t->set_##NAME##_val(i, *val_ptr++);                           \
-    }                                                               \
-  }
-
-    if (tensor->dtype() == DT_FLOAT) {
-      POPULATE_TENSOR_PROTO(tensor, t, float, float)
-    } else if (tensor->dtype() == DT_DOUBLE) {
-      POPULATE_TENSOR_PROTO(tensor, t, double, double)
-    } else if (tensor->dtype() == DT_INT64) {
-      POPULATE_TENSOR_PROTO(tensor, t, int64, int64)
-    } else if (tensor->dtype() == DT_INT32) {
-      POPULATE_TENSOR_PROTO(tensor, t, int32, int)
-    } else if (tensor->dtype() == DT_INT16) {
-      POPULATE_TENSOR_PROTO(tensor, t, int16, int)
-    } else if (tensor->dtype() == DT_INT8) {
-      POPULATE_TENSOR_PROTO(tensor, t, int8, int)
-    } else if (tensor->dtype() == DT_UINT8) {
-      POPULATE_TENSOR_PROTO(tensor, t, uint8, int)
-    } else if (tensor->dtype() == DT_BOOL) {
-      POPULATE_TENSOR_PROTO(tensor, t, bool, bool)
+#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME)                  \
+  {                                                                   \
+    const TYPE* val_ptr = tensor->flat<TYPE>().data();                \
+    TYPE last = *val_ptr;                                             \
+    int64 last_index = 0;                                             \
+    for (int64 i = 0; i < tensor->NumElements(); ++i) {               \
+      TYPE cur = *val_ptr++;                                          \
+      if (cur != last) {                                              \
+        last = cur;                                                   \
+        last_index = i;                                               \
+      }                                                               \
+    }                                                                 \
+    if (last_index < kint32max) {                                     \
+      optimized = true;                                               \
+      encoded_size = (last_index + 1) * sizeof(NAME);                 \
+      t->mutable_##NAME##_val()->Reserve(last_index + 1);             \
+      t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \
+      val_ptr = tensor->flat<TYPE>().data();                          \
+      for (int64 i = 0; i <= last_index; ++i) {                       \
+        t->set_##NAME##_val(i, *val_ptr++);                           \
+      }                                                               \
+    }                                                                 \
+  }                                                                   \
+  break
+
+    switch (tensor->dtype()) {
+      case DT_FLOAT:
+        POPULATE_TENSOR_PROTO(tensor, t, float, float);
+      case DT_DOUBLE:
+        POPULATE_TENSOR_PROTO(tensor, t, double, double);
+      case DT_INT64:
+        POPULATE_TENSOR_PROTO(tensor, t, int64, int64);
+      case DT_INT32:
+        POPULATE_TENSOR_PROTO(tensor, t, int32, int);
+      case DT_INT16:
+        POPULATE_TENSOR_PROTO(tensor, t, int16, int);
+      case DT_INT8:
+        POPULATE_TENSOR_PROTO(tensor, t, int8, int);
+      case DT_UINT8:
+        POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
+      case DT_BOOL:
+        POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
+      default:
+        /* Do nothing. */
+        break;
     }
   }
   if (optimized) {
@@ -1469,9 +1476,111 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       for (int j = 0; j < shape.dim_size(); ++j) {
         replaceable &= shape.dim(j).size() == 1;
       }
-      if (replaceable) ReplaceOperationWithIdentity(0, node, output);
+      if (replaceable) {
+        ReplaceOperationWithIdentity(0, node, output);
+      }
     }
 
+    // Switch(x, x) will always feed false to its false branch and true to
+    // its true branch. By rewriting the graph a bit, we can propagate these
+    // constants down the two output branches, and just use control dependencies
+    // to trigger the selected one at runtime. For example,
+    //
+    //     +------+
+    // x-->|Switch|-->a  (in practice there may be multiple consumers of each
+    // x-->|      |-->b   output branch.)
+    //     +------+
+    //
+    // Is rewritten as
+    //
+    //     +------+
+    // x-->|Switch|-->Identity--^>Const(false)-->a
+    // x-->|      |-->Identity--^>Const(true)-->b
+    //     +------+
+    if (node->op() == "Switch" && node->input(0) == node->input(1) &&
+        !OptimizedNodeExists(*node, "_const_false") &&
+        !OptimizedNodeExists(*node, "_const_true")) {
+      bool already_optimized = true;
+      // If the optimization was already applied, the switch would have exactly
+      // one Identity node consuming each of its outputs, each without any
+      // non-control outputs.
+      auto fanouts = node_map_->GetOutputs(node->name());
+      if (fanouts.size() == 2) {
+        for (NodeDef* fanout : fanouts) {
+          if (!IsIdentity(*fanout) ||
+              NumNonControlOutputs(*fanout, *node_map_) > 0) {
+            already_optimized = false;
+            break;
+          }
+        }
+      }
+      Tensor false_t(DT_BOOL, TensorShape({}));
+      Tensor true_t(DT_BOOL, TensorShape({}));
+      // Make sure we don't proceed if this switch node was already optimized.
+      if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
+          SetTensorValue(DT_BOOL, false, &false_t).ok()) {
+        // Copy the set of consumers of the switch as they will be manipulated
+        // below.
+        const std::set<NodeDef*>& consumer_set =
+            node_map_->GetOutputs(node->name());
+        std::vector<NodeDef*> consumers(consumer_set.begin(),
+                                        consumer_set.end());
+        std::sort(consumers.begin(), consumers.end(),
+                  [](const NodeDef* n1, const NodeDef* n2) {
+                    return n1->name() < n2->name();
+                  });
+        // Create constant false & true nodes.
+        NodeDef* false_node = output->add_node();
+        false_node->set_name(OptimizedNodeName(*node, "_const_false"));
+        if (!CreateNodeDef(false_node->name(), TensorValue(&false_t),
+                           false_node)
+                 .ok()) {
+          continue;
+        }
+        false_node->set_device(node->device());
+
+        NodeDef* true_node = output->add_node();
+        true_node->set_name(OptimizedNodeName(*node, "_const_true"));
+        if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
+                 .ok()) {
+          continue;
+        }
+        true_node->set_device(node->device());
+
+        // Add controls from the switch ports to the constants, and connect the
+        // constants to the original switch outputs.
+        const string false_port = node->name();
+        const string true_port = strings::StrCat(node->name(), ":1");
+        const string false_ctrl_dep =
+            AddControlDependency(false_port, output, node_map_.get());
+        false_node->add_input(false_ctrl_dep);
+        const string true_ctrl_dep =
+            AddControlDependency(true_port, output, node_map_.get());
+        true_node->add_input(true_ctrl_dep);
+
+        node_map_->AddNode(false_node->name(), false_node);
+        node_map_->AddNode(true_node->name(), true_node);
+        node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
+        node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
+
+        for (NodeDef* consumer : consumers) {
+          for (int i = 0; i < consumer->input_size(); ++i) {
+            const string& input = consumer->input(i);
+            if (input == false_port) {
+              consumer->set_input(i, false_node->name());
+              node_map_->UpdateInput(consumer->name(), false_port,
+                                     false_node->name());
+            } else if (input == true_port) {
+              consumer->set_input(i, true_node->name());
+              node_map_->UpdateInput(consumer->name(), true_port,
+                                     true_node->name());
+            }
+          }
+        }
+        graph_modified_ = true;
+        continue;
+      }
+    }
     if (IsSimplifiableReduction(*node)) {
       // Replace the reduction node with an identity node, that can be further
       // optimized by the model pruner.
@@ -1547,9 +1656,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       const bool y_is_zero = IsZeros(*y);
       const bool y_is_one = IsOnes(*y);
       const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
-      if (x_matches_output_shape &&
-          (((is_mul || is_any_div) && y_is_one) ||
-           ((is_add || is_sub) && y_is_zero))) {
+      if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
+                                     ((is_add || is_sub) && y_is_zero))) {
         // x * 1 = x or x / 1 = x or x +/- 0 = x
         ReplaceOperationWithSnapshot(0, node, output);
         continue;
@@ -1601,8 +1709,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
       // Insert new reciprocal op and change node from Div to Mul.
       NodeDef* reciprocal_node = output->add_node();
-      reciprocal_node->set_name(AddPrefixToNodeName(
-          strings::StrCat(node->name(), "_recip"), kConstantFoldingConst));
+      reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
       reciprocal_node->set_op("Reciprocal");
       reciprocal_node->set_device(node->device());
       node->set_op("Mul");
@@ -1701,6 +1808,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       graph_modified_ = true;
     }
   }
+
   return Status::OK();
 }
 
@@ -1779,5 +1887,5 @@ void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
   // Nothing to do for ConstantFolding.
 }
 
-}  // end namespace grappler
-}  // end namespace tensorflow
+}  // namespace grappler
+}  // namespace tensorflow
tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 2048692c227900cc8a1a3c3693e26573792eb37a..219f3bd5ec2a1c15078972bdea69a7642bb4af46 100644 (file)
@@ -469,7 +469,6 @@ TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
-  LOG(INFO) << output.DebugString();
 
   EXPECT_EQ(10, output.node_size());
   for (int i = 0; i < output.node_size(); ++i) {
@@ -991,8 +990,10 @@ TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
   EXPECT_EQ(present_nodes.size(), output.node_size());
   int found = 0;
   for (const auto& node : output.node()) {
-    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
-    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
+    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end())
+        << node.name();
+    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end())
+        << node.name();
     present_nodes.erase(node.name());
     not_present_nodes.erase(node.name());
     if (node.name() == "rank") {
@@ -1212,7 +1213,8 @@ TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
 }
 
 TEST_F(ConstantFoldingTest, NoOpReduction) {
-  // Build a simple graph with a reduction that can be reduced to the identity.
+  // Build a simple graph with a reduction that can be reduced to the
+  // identity.
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
 
   Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
@@ -1338,8 +1340,8 @@ TEST_F(ConstantFoldingTest, Packing) {
   TF_EXPECT_OK(status);
 
   // Make sure that the representation of the folded constant is space
-  // efficient: in particular, the whole message should be smaller than 8k (the
-  // size needed to naively encode 1000 floats folded twice).
+  // efficient: in particular, the whole message should be smaller than 8k
+  // (the size needed to naively encode 1000 floats folded twice).
   EXPECT_GT(8000, output.ByteSizeLong());
 }
 
@@ -1494,6 +1496,58 @@ TEST_F(ConstantFoldingTest, LargeConstant) {
   EXPECT_GT(1024 * 1024, output.ByteSizeLong());
 }
 
+TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL,
+                              ops::Placeholder::Shape(TensorShape({})));
+  ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x);
+  Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false);
+  Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true);
+
+  GrapplerItem item;
+  item.fetch.push_back("id_false");
+  item.fetch.push_back("id_true");
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(6, output.node_size());
+  int found = 0;
+  for (const auto& node : output.node()) {
+    if (node.name() == "switch" || node.name() == "x") {
+      ++found;
+    }
+    if (node.name() == "id_false") {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
+      ++found;
+    }
+    if (node.name() == "id_true") {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
+      ++found;
+    }
+    if (node.name() == "ConstantFoldingCtrl/switch_0") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("switch", node.input(0));
+      ++found;
+    }
+    if (node.name() == "ConstantFoldingCtrl/switch_1") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("switch:1", node.input(0));
+      ++found;
+    }
+  }
+  EXPECT_EQ(6, found);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
tensorflow/core/grappler/utils.cc
index eb5a2c48dc8b12f7b4090e80c403e238a526e122..81bb5e6c3b26ebbed8cd1555c10d2dd6f2a47c12 100644 (file)
@@ -29,6 +29,18 @@ limitations under the License.
 
 namespace tensorflow {
 namespace grappler {
+namespace {
+template <typename T>
+bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
+  using RealType = typename Eigen::NumTraits<T>::Real;
+  if (value > std::numeric_limits<RealType>::max() ||
+      value < std::numeric_limits<RealType>::min()) {
+    return false;
+  }
+  tensor->flat<T>()(0) = static_cast<T>(value);
+  return true;
+}
+}  // namespace
 
 NodeMap::NodeMap(GraphDef* graph) {
   CHECK(graph != nullptr);
@@ -402,5 +414,43 @@ string SimpleGraphView::PrintToString() const {
   return str;
 }
 
+#define HANDLE_CASE(DTYPE)                                          \
+  case DTYPE:                                                       \
+    if (!SafeSetScalarTensorValue<EnumToDataType<DTYPE>::Type>(     \
+            static_cast<double>(value), tensor)) {                  \
+      return errors::InvalidArgument("Cannot store value ", value,  \
+                                     " in tensor of type " #DTYPE); \
+    }                                                               \
+    break
+
+Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
+  // TODO(rmlarsen): Support more general shapes.
+  if (tensor->NumElements() != 1) {
+    return errors::InvalidArgument(
+        "Expected scalar tensor, got num_elements = ", tensor->NumElements());
+  }
+  switch (dtype) {
+    // TODO(rmlarsen): Handle DT_HALF.
+    //    HANDLE_CASE(DT_HALF);
+    HANDLE_CASE(DT_BOOL);
+    HANDLE_CASE(DT_FLOAT);
+    HANDLE_CASE(DT_DOUBLE);
+    HANDLE_CASE(DT_UINT8);
+    HANDLE_CASE(DT_INT8);
+    HANDLE_CASE(DT_UINT16);
+    HANDLE_CASE(DT_INT16);
+    HANDLE_CASE(DT_INT32);
+    HANDLE_CASE(DT_INT64);
+    HANDLE_CASE(DT_COMPLEX64);
+    HANDLE_CASE(DT_COMPLEX128);
+    default:
+      return errors::InvalidArgument("Unsupported type ",
+                                     DataTypeString(dtype));
+  }
+  return Status::OK();
+}
+
+#undef HANDLE_CASE
+
 }  // end namespace grappler
 }  // end namespace tensorflow
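
For reference, a small illustrative sketch (not part of the commit) of how the relocated SetTensorValue helper can be called; the function name SetTensorValueExample and the chosen dtypes are arbitrary. It exercises the scalar path and the InvalidArgument error for non-scalar tensors shown above.

  #include "tensorflow/core/framework/tensor.h"
  #include "tensorflow/core/grappler/utils.h"
  #include "tensorflow/core/lib/core/status.h"

  void SetTensorValueExample() {
    using tensorflow::DT_BOOL;
    using tensorflow::DT_FLOAT;
    using tensorflow::Tensor;
    using tensorflow::TensorShape;

    // Scalar tensors are supported: writes `true` into a DT_BOOL scalar.
    Tensor flag(DT_BOOL, TensorShape({}));
    TF_CHECK_OK(tensorflow::grappler::SetTensorValue(DT_BOOL, 1, &flag));

    // Non-scalar tensors are rejected with InvalidArgument.
    Tensor vec(DT_FLOAT, TensorShape({2}));
    tensorflow::Status status =
        tensorflow::grappler::SetTensorValue(DT_FLOAT, 1, &vec);
    // status.ok() is false here.
  }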
tensorflow/core/grappler/utils.h
index 4ecb28f681507f50ad5909f15cf1b408ed6e2979..255319693a57a7cc493365a51d5d04d2893f08c5 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/threadpool.h"
@@ -167,6 +168,8 @@ NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                          bool invert_permutation);
 
+Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
+
 class SimpleGraphView {
  public:
   Status Initialize(const GraphDef& graph) {