From a2048b8ce0e8ab37c5cf75bc21b503093091673b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 24 May 2018 03:48:24 -0700 Subject: [PATCH] Automated g4 rollback of changelist 197477959 PiperOrigin-RevId: 197868028 --- .../grappler/optimizers/arithmetic_optimizer.cc | 43 --------- .../grappler/optimizers/arithmetic_optimizer.h | 1 - .../optimizers/arithmetic_optimizer_test.cc | 103 --------------------- 3 files changed, 147 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 060e420..e7f70c6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1380,47 +1380,6 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { } }; -class RemoveLogicalNotStage : public ArithmeticOptimizerStage { - public: - explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx, - const ArithmeticOptimizerContext& ctx_ext) - : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {} - ~RemoveLogicalNotStage() override = default; - - bool IsSupported(const NodeDef* node) const override { - return IsLogicalNot(*node) && !IsInPreserveSet(*node); - } - - Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - const string node_name = node->name(); - NodeDef* input; - TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); - if (IsInPreserveSet(*input) || - NumNonControlOutputs(*input, *ctx().node_map) > 1) { - return Status::OK(); - } - string new_op; - if (IsEqual(*input)) { - new_op = "NotEqual"; - } else if (IsNotEqual(*input)) { - new_op = "Equal"; - } else if (IsLess(*input)) { - new_op = "GreaterEqual"; - } else if (IsLessEqual(*input)) { - new_op = "Greater"; - } else if (IsGreater(*input)) { - new_op = "LessEqual"; - } else if (IsGreaterEqual(*input)) { - new_op = "Less"; - } - if (!new_op.empty()) { - input->set_op(new_op); - *simplified_node_name = input->name(); - } - return Status::OK(); - } -}; - // This optimization hoists the common prefix of unary ops of the inputs to // concat out of the concat, for example: // Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))]) @@ -2470,8 +2429,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage(ctx, ctx_ext); - if (options_.remove_logical_not) - pipeline.AddStage(ctx, ctx_ext); if (options_.hoist_cwise_unary_chains) pipeline.AddStage(ctx, ctx_ext); if (options_.convert_sqrt_div_to_rsqrt_mul) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 8e1b3ed..1f6f563 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -68,7 +68,6 @@ class ArithmeticOptimizer : public GraphOptimizer { bool hoist_cwise_unary_chains = false; bool convert_sqrt_div_to_rsqrt_mul = false; bool remove_idempotent = true; - bool remove_logical_not = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 64fdc8a..99f93e6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -177,11 +177,6 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_idempotent = true; } - - void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_logical_not = true; - } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2742,103 +2737,5 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { } } -TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Const(s.WithOpName("a"), 3.14f, {32}); - Output b = ops::Const(s.WithOpName("b"), -3.14f, {32}); - Output eq = ops::Equal(s.WithOpName("eq"), a, b); - Output neq = ops::NotEqual(s.WithOpName("neq"), a, b); - Output lt = ops::Less(s.WithOpName("lt"), a, b); - Output le = ops::LessEqual(s.WithOpName("le"), a, b); - Output gt = ops::Greater(s.WithOpName("gt"), a, b); - Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b); - // not_eq is reserved - Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq); - Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq); - Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt); - Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le); - Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt); - Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge); - Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1); - Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq); - Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt); - Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le); - Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt); - Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge); - - GrapplerItem item; - item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt", - "id_not_le", "id_not_gt", "id_not_ge"}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - - GraphDef output; - ArithmeticOptimizer optimizer; - EnableOnlyRemoveLogicalNot(&optimizer); - OptimizeTwice(&optimizer, &item, &output); - LOG(INFO) << output.DebugString(); - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "id_not_eq") { - EXPECT_EQ("eq", node.input(0)); - ++found; - } - if (node.name() == "id_not_neq") { - EXPECT_EQ("neq", node.input(0)); - ++found; - } - if (node.name() == "id_not_lt") { - EXPECT_EQ("lt", node.input(0)); - ++found; - } - if (node.name() == "id_not_le") { - EXPECT_EQ("le", node.input(0)); - ++found; - } - if (node.name() == "id_not_gt") { - EXPECT_EQ("gt", node.input(0)); - ++found; - } - if (node.name() == "id_not_ge") { - EXPECT_EQ("ge", node.input(0)); - ++found; - } - - if (node.name() == "eq") { - EXPECT_EQ("NotEqual", node.op()); - ++found; - } - if (node.name() == "neq") { - EXPECT_EQ("Equal", node.op()); - ++found; - } - if (node.name() == "lt") { - EXPECT_EQ("GreaterEqual", node.op()); - ++found; - } - if (node.name() == "le") { - EXPECT_EQ("Greater", node.op()); - ++found; - } - if (node.name() == "gt") { - EXPECT_EQ("LessEqual", node.op()); - ++found; - } - if (node.name() == "ge") { - EXPECT_EQ("Less", node.op()); - ++found; - } - } - EXPECT_EQ(12, found); - - auto tensors = EvaluateNodes(output, item.fetch); - EXPECT_EQ(tensors.size(), tensors_expected.size()); - EXPECT_EQ(tensors.size(), item.fetch.size()); - for (int i = 0; i < item.fetch.size(); ++i) { - test::ExpectTensorEqual(tensors_expected[i], tensors[i]); - } -} - } // namespace grappler } // namespace tensorflow -- 2.7.4