From 4d03411da6fcc803d9abcef97a59072144e325f9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 21 May 2018 16:14:10 -0700 Subject: [PATCH] Add arithmetic optimizer stage that removes LogicalNot that takes a comparison as input, i.e. !(a == b) => a != b !(a != b) => a == b !(a < b) => a >= b !(a <= b) => a > b !(a > b) => a <= b !(a >= b) => a < b PiperOrigin-RevId: 197477959 --- .../grappler/optimizers/arithmetic_optimizer.cc | 43 +++++++++ .../grappler/optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 103 +++++++++++++++++++++ 3 files changed, 147 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index e7f70c6..060e420 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1380,6 +1380,47 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { } }; +class RemoveLogicalNotStage : public ArithmeticOptimizerStage { + public: + explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {} + ~RemoveLogicalNotStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsLogicalNot(*node) && !IsInPreserveSet(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const string node_name = node->name(); + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + if (IsInPreserveSet(*input) || + NumNonControlOutputs(*input, *ctx().node_map) > 1) { + return Status::OK(); + } + string new_op; + if (IsEqual(*input)) { + new_op = "NotEqual"; + } else if (IsNotEqual(*input)) { + new_op = "Equal"; + } else if (IsLess(*input)) { + new_op = "GreaterEqual"; + } else if (IsLessEqual(*input)) { + new_op = "Greater"; + } else if (IsGreater(*input)) { + new_op = "LessEqual"; + } else if (IsGreaterEqual(*input)) { + new_op = "Less"; + } + if (!new_op.empty()) { + input->set_op(new_op); + *simplified_node_name = input->name(); + } + return Status::OK(); + } +}; + // This optimization hoists the common prefix of unary ops of the inputs to // concat out of the concat, for example: // Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))]) @@ -2429,6 +2470,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage(ctx, ctx_ext); + if (options_.remove_logical_not) + pipeline.AddStage(ctx, ctx_ext); if (options_.hoist_cwise_unary_chains) pipeline.AddStage(ctx, ctx_ext); if (options_.convert_sqrt_div_to_rsqrt_mul) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 1f6f563..8e1b3ed 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -68,6 +68,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool hoist_cwise_unary_chains = false; bool convert_sqrt_div_to_rsqrt_mul = false; bool remove_idempotent = true; + bool remove_logical_not = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 99f93e6..64fdc8a 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -177,6 +177,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_idempotent = true; } + + void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_logical_not = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2737,5 +2742,103 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { } } +TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 3.14f, {32}); + Output b = ops::Const(s.WithOpName("b"), -3.14f, {32}); + Output eq = ops::Equal(s.WithOpName("eq"), a, b); + Output neq = ops::NotEqual(s.WithOpName("neq"), a, b); + Output lt = ops::Less(s.WithOpName("lt"), a, b); + Output le = ops::LessEqual(s.WithOpName("le"), a, b); + Output gt = ops::Greater(s.WithOpName("gt"), a, b); + Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b); + // not_eq is reserved + Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq); + Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq); + Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt); + Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le); + Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt); + Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge); + Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1); + Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq); + Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt); + Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le); + Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt); + Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge); + + GrapplerItem item; + item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt", + "id_not_le", "id_not_gt", "id_not_ge"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveLogicalNot(&optimizer); + OptimizeTwice(&optimizer, &item, &output); + LOG(INFO) << output.DebugString(); + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "id_not_eq") { + EXPECT_EQ("eq", node.input(0)); + ++found; + } + if (node.name() == "id_not_neq") { + EXPECT_EQ("neq", node.input(0)); + ++found; + } + if (node.name() == "id_not_lt") { + EXPECT_EQ("lt", node.input(0)); + ++found; + } + if (node.name() == "id_not_le") { + EXPECT_EQ("le", node.input(0)); + ++found; + } + if (node.name() == "id_not_gt") { + EXPECT_EQ("gt", node.input(0)); + ++found; + } + if (node.name() == "id_not_ge") { + EXPECT_EQ("ge", node.input(0)); + ++found; + } + + if (node.name() == "eq") { + EXPECT_EQ("NotEqual", node.op()); + ++found; + } + if (node.name() == "neq") { + EXPECT_EQ("Equal", node.op()); + ++found; + } + if (node.name() == "lt") { + EXPECT_EQ("GreaterEqual", node.op()); + ++found; + } + if (node.name() == "le") { + EXPECT_EQ("Greater", node.op()); + ++found; + } + if (node.name() == "gt") { + EXPECT_EQ("LessEqual", node.op()); + ++found; + } + if (node.name() == "ge") { + EXPECT_EQ("Less", node.op()); + ++found; + } + } + EXPECT_EQ(12, found); + + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(tensors.size(), tensors_expected.size()); + EXPECT_EQ(tensors.size(), item.fetch.size()); + for (int i = 0; i < item.fetch.size(); ++i) { + test::ExpectTensorEqual(tensors_expected[i], tensors[i]); + } +} + } // namespace grappler } // namespace tensorflow -- 2.7.4