convert Pow op into something that is more recognizable, so we can have further
author A. Unique TensorFlower <gardener@tensorflow.org>
Tue, 22 May 2018 08:35:36 +0000 (01:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 08:38:17 +0000 (01:38 -0700)
optimizations.

PiperOrigin-RevId: 197527651

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index 060e420..21a8602 100644 (file)
@@ -1819,6 +1819,141 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
   }
 };
 
+class ConvertPowStage : public ArithmeticOptimizerStage {
+ public:
+  explicit ConvertPowStage(const GraphOptimizerContext& ctx,
+                           const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsPow(*node) &&
+           ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
+    for (int i = 0; i < p.shape().dim_size(); ++i) {
+      if (p.shape().dim(i).size() < 0) {
+        // skip if p is is not fully defined.
+        return Status::OK();
+      }
+    }
+    if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+      Tensor pow(p.dtype(), p.shape());
+      if (!pow.FromProto(p.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       p.value().DebugString());
+      }
+
+      complex128 prev, curr;
+      for (int i = 0; i < pow.NumElements(); ++i) {
+        TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+        if (i != 0 && curr != prev) {
+          // pow has different values on different elements. Skip.
+          return Status::OK();
+        }
+        prev = curr;
+      }
+      NodeDef *x, *y;
+      TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
+      TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+      if (curr == complex128(2, 0)) {
+        node->set_op("Square");
+        node->set_input(1, AsControlDependency(y->name()));
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(y);
+      } else if (curr == complex128(1, 0)) {
+        node->set_op("Identity");
+        node->set_input(1, AsControlDependency(y->name()));
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(y);
+      } else if (curr == complex128(0.5, 0)) {
+        node->set_op("Sqrt");
+        node->set_input(1, AsControlDependency(y->name()));
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(y);
+      } else if (curr == complex128(0, 0)) {
+        node->set_op("Const");
+        Tensor c(pow.dtype(), pow.shape());
+        for (int i = 0; i < c.NumElements(); ++i) {
+          TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
+        }
+        (*node->mutable_attr())["dtype"].set_type(pow.dtype());
+        c.AsProtoTensorContent(
+            (*node->mutable_attr())["value"].mutable_tensor());
+        node->mutable_attr()->erase("T");
+        node->set_input(0, AsControlDependency(x->name()));
+        node->set_input(1, AsControlDependency(y->name()));
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(x);
+        AddToOptimizationQueue(y);
+      } else if (curr == complex128(-0.5, 0)) {
+        node->set_op("Rsqrt");
+        node->set_input(1, AsControlDependency(y->name()));
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(y);
+      } else if (curr == complex128(-1, 0)) {
+        node->set_op("Reciprocal");
+        node->set_input(1, AsControlDependency(y->name()));
+        AddToOptimizationQueue(node);
+        AddToOptimizationQueue(y);
+      }
+    }
+    return Status::OK();
+  }
+
+ private:
+  Status GetElement(const Tensor& t, int i, complex128* element) {
+    switch (t.dtype()) {
+      case DT_INT32:
+        *element = complex128(t.flat<int32>()(i));
+        return Status::OK();
+      case DT_INT64:
+        *element = complex128(t.flat<int64>()(i));
+        return Status::OK();
+      case DT_FLOAT:
+        *element = complex128(t.flat<float>()(i));
+        return Status::OK();
+      case DT_DOUBLE:
+        *element = complex128(t.flat<double>()(i));
+        return Status::OK();
+      case DT_COMPLEX64:
+        *element = complex128(t.flat<complex64>()(i));
+        return Status::OK();
+      case DT_COMPLEX128:
+        *element = t.flat<complex128>()(i);
+        return Status::OK();
+      default:
+        return errors::InvalidArgument("Invalid data type: ", t.dtype());
+    }
+  }
+
+  Status SetElementToOne(int i, Tensor* t) {
+    switch (t->dtype()) {
+      case DT_INT32:
+        t->flat<int32>()(i) = 1;
+        return Status::OK();
+      case DT_INT64:
+        t->flat<int64>()(i) = 1L;
+        return Status::OK();
+      case DT_FLOAT:
+        t->flat<float>()(i) = 1.0f;
+        return Status::OK();
+      case DT_DOUBLE:
+        t->flat<double>()(i) = 1.0;
+        return Status::OK();
+      case DT_COMPLEX64:
+        t->flat<complex64>()(i) = complex64(1);
+        return Status::OK();
+      case DT_COMPLEX128:
+        t->flat<complex128>()(i) = complex128(1);
+        return Status::OK();
+      default:
+        return errors::InvalidArgument("Invalid data type: ", t->dtype());
+    }
+  }
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -2478,6 +2613,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
   if (options_.remove_idempotent)
     pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
+  if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
 
   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
           << str_util::Join(pipeline.StageNames(), ", ");
index 8e1b3ed..420b72b 100644 (file)
@@ -69,6 +69,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool convert_sqrt_div_to_rsqrt_mul = false;
     bool remove_idempotent = true;
     bool remove_logical_not = true;
+    bool convert_pow = true;  // rewrite Pow with a constant exponent
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
index 64fdc8a..8b8eedf 100644 (file)
@@ -173,6 +173,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
   }
 
+  void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.convert_pow = true;  // leave only ConvertPowStage on
+  }
+
   void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.remove_idempotent = true;
@@ -2243,6 +2248,58 @@ TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
   }
 }
 
+TEST_F(ArithmeticOptimizerTest, ConvertPow) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+  auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
+  auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
+  auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
+  auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
+  auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
+  auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
+  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+  Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
+  Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
+  Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
+  Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
+  Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
+  Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
+  Output out = ops::Pow(s.WithOpName("out"), x, y);  // non-uniform exponent
+
+  GrapplerItem item;
+  item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(7, tensors_expected.size());
+
+  GraphDef got;  // graph after running only the ConvertPow stage
+  ArithmeticOptimizer optimizer;
+  EnableOnlyConvertPow(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &got);
+  auto tensors = EvaluateNodes(got, item.fetch);
+  EXPECT_EQ(7, tensors.size());
+
+  GraphDef want;  // expected graph: each Pow rewritten per its exponent
+  AddNode("x", "Const", {}, {}, &want);
+  AddNode("y2", "Const", {}, {}, &want);
+  AddNode("y1", "Const", {}, {}, &want);
+  AddNode("y.5", "Const", {}, {}, &want);
+  AddNode("y0", "Const", {}, {}, &want);
+  AddNode("y_.5", "Const", {}, {}, &want);
+  AddNode("y_1", "Const", {}, {}, &want);
+  AddNode("y", "Const", {}, {}, &want);
+  AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
+  AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
+  AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
+  AddNode("out0", "Const",
+          {AsControlDependency("x"), AsControlDependency("y0")}, {}, &want);
+  AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
+  AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
+  AddNode("out", "Pow", {"x", "y"}, {}, &want);  // left untouched
+
+  CompareGraphs(want, got);
+}
+
 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();