From 6959b8e76f18f63aacaaf24dd74b11d733b57314 Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Fri, 27 Apr 2018 21:23:20 +0000 Subject: [PATCH] [PatternMatch] Stabilize the matching order of commutative matchers Summary: Currently, we 1. match `LHS` matcher to the `first` operand of binary operator, 2. and then match `RHS` matcher to the `second` operand of binary operator. If that does not match, we swap the `LHS` and `RHS` matchers: 1. match `RHS` matcher to the `first` operand of binary operator, 2. and then match `LHS` matcher to the `second` operand of binary operator. This works ok. But it complicates writing of commutative matchers, where one would like to match (`m_Value()`) the value on one side, and use (`m_Specific()`) it on the other side. This is additionally complicated by the fact that `m_Specific()` stores the `Value *`, not `Value **`, so it won't work at all out of the box. The last problem is trivially solved by adding a new `m_c_Specific()` that stores the `Value **`, not `Value *`. I'm choosing to add a new matcher, not change the existing one because i guess all the current users are ok with existing behavior, and this additional pointer indirection may have performance drawbacks. Also, i'm storing pointer, not reference, because for some mysterious-to-me reason it did not work with the reference. The first one appears trivial, too. Currently, we 1. match `LHS` matcher to the `first` operand of binary operator, 2. and then match `RHS` matcher to the `second` operand of binary operator. If that does not match, we swap the ~~`LHS` and `RHS` matchers~~ **operands**: 1. match ~~`RHS`~~ **`LHS`** matcher to the ~~`first`~~ **`second`** operand of binary operator, 2. and then match ~~`LHS`~~ **`RHS`** matcher to the ~~`second`~ **`first`** operand of binary operator. Surprisingly, `$ ninja check-llvm` still passes with this. But i expect the bots will disagree.. The motivational unittest is included. I'd like to use this in D45664. Reviewers: spatel, craig.topper, arsenm, RKSimon Reviewed By: craig.topper Subscribers: xbolva00, wdng, llvm-commits Differential Revision: https://reviews.llvm.org/D45828 llvm-svn: 331085 --- llvm/include/llvm/IR/PatternMatch.h | 42 ++++++++++++++---- llvm/lib/Analysis/ValueTracking.cpp | 7 +-- .../Transforms/InstCombine/InstCombineAddSub.cpp | 19 ++------ .../Transforms/InstCombine/InstCombineAndOrXor.cpp | 26 +++++------ llvm/test/Transforms/InstCombine/and-or-not.ll | 10 ++--- llvm/test/Transforms/InstCombine/or-xor.ll | 8 ++-- llvm/unittests/IR/PatternMatch.cpp | 50 ++++++++++++++++++++++ 7 files changed, 108 insertions(+), 54 deletions(-) diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 8b1c6b2..dcca931 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -489,6 +489,22 @@ struct specificval_ty { /// Match if we have a specific specified value. inline specificval_ty m_Specific(const Value *V) { return V; } +/// Stores a reference to the Value *, not the Value * itself, +/// thus can be used in commutative matchers. +template struct deferredval_ty { + Class *const &Val; + + deferredval_ty(Class *const &V) : Val(V) {} + + template bool match(ITy *const V) { return V == Val; } +}; + +/// A commutative-friendly version of m_Specific(). +inline deferredval_ty m_Deferred(Value *const &V) { return V; } +inline deferredval_ty m_Deferred(const Value *const &V) { + return V; +} + /// Match a specified floating point value or vector of all elements of /// that value. struct specific_fpval { @@ -562,13 +578,15 @@ struct AnyBinaryOp_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { if (auto *I = dyn_cast(V)) return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && R.match(I->getOperand(0)) && - L.match(I->getOperand(1))); + (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0))); return false; } }; @@ -588,20 +606,22 @@ struct BinaryOp_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { if (V->getValueID() == Value::InstructionVal + Opcode) { auto *I = cast(V); return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && R.match(I->getOperand(0)) && - L.match(I->getOperand(1))); + (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0))); } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && ((L.match(CE->getOperand(0)) && R.match(CE->getOperand(1))) || - (Commutable && R.match(CE->getOperand(0)) && - L.match(CE->getOperand(1)))); + (Commutable && L.match(CE->getOperand(1)) && + R.match(CE->getOperand(0)))); return false; } }; @@ -926,14 +946,16 @@ struct CmpClass_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} template bool match(OpTy *V) { if (auto *I = dyn_cast(V)) if ((L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && R.match(I->getOperand(0)) && - L.match(I->getOperand(1)))) { + (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0)))) { Predicate = I->getPredicate(); return true; } @@ -1251,6 +1273,8 @@ struct MaxMin_match { LHS_t L; RHS_t R; + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { @@ -1277,7 +1301,7 @@ struct MaxMin_match { return false; // It does! Bind the operands. return (L.match(LHS) && R.match(RHS)) || - (Commutable && R.match(LHS) && L.match(RHS)); + (Commutable && L.match(RHS) && R.match(LHS)); } }; diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index ac29cde..a799f7c 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -982,12 +982,9 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // matching the form add(x, add(x, y)) where y is odd. // TODO: This could be generalized to clearing any bit set in y where the // following bit is known to be unset in y. - Value *Y = nullptr; + Value *X = nullptr, *Y = nullptr; if (!Known.Zero[0] && !Known.One[0] && - (match(I->getOperand(0), m_Add(m_Specific(I->getOperand(1)), - m_Value(Y))) || - match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), - m_Value(Y))))) { + match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) { Known2.resetAll(); computeKnownBits(Y, Known2, Depth + 1, Q); if (Known2.countMinTrailingOnes() > 0) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index b7e115e..274bde0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1363,26 +1363,15 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } // (add (xor A, B) (and A, B)) --> (or A, B) - if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - // (add (and A, B) (xor A, B)) --> (or A, B) - if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) + if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) return BinaryOperator::CreateOr(A, B); // (add (or A, B) (and A, B)) --> (add A, B) - if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } - // (add (and A, B) (or A, B)) --> (add A, B) - if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { + if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 1d90d236..f4aec91 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1288,8 +1288,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // Operand complexity canonicalization guarantees that the 'or' is Op0. // (A | B) & ~(A & B) --> A ^ B // (A | B) & ~(B & A) --> A ^ B - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_c_And(m_Specific(A), m_Specific(B))))) + if (match(&I, m_BinOp(m_Or(m_Value(A), m_Value(B)), + m_Not(m_c_And(m_Deferred(A), m_Deferred(B)))))) return BinaryOperator::CreateXor(A, B); // (A | ~B) & (~A | B) --> ~(A ^ B) @@ -1297,8 +1297,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // (~B | A) & (~A | B) --> ~(A ^ B) // (~B | A) & (B | ~A) --> ~(A ^ B) if (Op0->hasOneUse() || Op1->hasOneUse()) - if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(&I, m_BinOp(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); return nullptr; @@ -2294,10 +2294,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (A & B) ^ (B | A) -> A ^ B // (A | B) ^ (A & B) -> A ^ B // (A | B) ^ (B & A) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) || - (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Specific(B))))) { + if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)), + m_c_Or(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2307,10 +2305,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B | A) ^ (~A | B) -> A ^ B // (~A | B) ^ (A | ~B) -> A ^ B // (B | ~A) ^ (A | ~B) -> A ^ B - if ((match(Op0, m_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2320,10 +2316,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B & A) ^ (~A & B) -> A ^ B // (~A & B) ^ (A & ~B) -> A ^ B // (B & ~A) ^ (A & ~B) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))), + m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; diff --git a/llvm/test/Transforms/InstCombine/and-or-not.ll b/llvm/test/Transforms/InstCombine/and-or-not.ll index 2d3fb42..7bd4ad7 100644 --- a/llvm/test/Transforms/InstCombine/and-or-not.ll +++ b/llvm/test/Transforms/InstCombine/and-or-not.ll @@ -333,7 +333,7 @@ define i32 @xor_to_xor3(i32 %a, i32 %b) { define i32 @xor_to_xor4(i32 %a, i32 %b) { ; CHECK-LABEL: @xor_to_xor4( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %or = or i32 %a, %b @@ -389,7 +389,7 @@ define i32 @xor_to_xor7(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor7( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 @@ -408,7 +408,7 @@ define i32 @xor_to_xor8(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor8( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 @@ -465,7 +465,7 @@ define i32 @xor_to_xor11(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor11( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 @@ -484,7 +484,7 @@ define i32 @xor_to_xor12(float %fa, float %fb) { ; CHECK-LABEL: @xor_to_xor12( ; CHECK-NEXT: [[A:%.*]] = fptosi float [[FA:%.*]] to i32 ; CHECK-NEXT: [[B:%.*]] = fptosi float [[FB:%.*]] to i32 -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[A]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[B]], [[A]] ; CHECK-NEXT: ret i32 [[XOR]] ; %a = fptosi float %fa to i32 diff --git a/llvm/test/Transforms/InstCombine/or-xor.ll b/llvm/test/Transforms/InstCombine/or-xor.ll index 7d163d5..ab5f2f8 100644 --- a/llvm/test/Transforms/InstCombine/or-xor.ll +++ b/llvm/test/Transforms/InstCombine/or-xor.ll @@ -188,7 +188,7 @@ define i32 @test13(i32 %x, i32 %y) { ; ((x | ~y) ^ (~x | y)) -> x ^ y define i32 @test14(i32 %x, i32 %y) { ; CHECK-LABEL: @test14( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 @@ -201,7 +201,7 @@ define i32 @test14(i32 %x, i32 %y) { define i32 @test14_commuted(i32 %x, i32 %y) { ; CHECK-LABEL: @test14_commuted( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 @@ -215,7 +215,7 @@ define i32 @test14_commuted(i32 %x, i32 %y) { ; ((x & ~y) ^ (~x & y)) -> x ^ y define i32 @test15(i32 %x, i32 %y) { ; CHECK-LABEL: @test15( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 @@ -228,7 +228,7 @@ define i32 @test15(i32 %x, i32 %y) { define i32 @test15_commuted(i32 %x, i32 %y) { ; CHECK-LABEL: @test15_commuted( -; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i32 [[XOR]] ; %noty = xor i32 %y, -1 diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp index 6bdcf4d..2c6da22 100644 --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -65,6 +65,56 @@ TEST_F(PatternMatchTest, OneUse) { EXPECT_FALSE(m_OneUse(m_Value()).match(Leaf)); } +TEST_F(PatternMatchTest, CommutativeDeferredValue) { + Value *X = IRB.getInt32(1); + Value *Y = IRB.getInt32(2); + + { + Value *tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + { + const Value *tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + { + Value *const tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + { + const Value *const tX = X; + EXPECT_TRUE(match(X, m_Deferred(tX))); + EXPECT_FALSE(match(Y, m_Deferred(tX))); + } + + { + Value *tX = nullptr; + EXPECT_TRUE(match(IRB.CreateAnd(X, X), m_And(m_Value(tX), m_Deferred(tX)))); + EXPECT_EQ(tX, X); + } + { + Value *tX = nullptr; + EXPECT_FALSE( + match(IRB.CreateAnd(X, Y), m_c_And(m_Value(tX), m_Deferred(tX)))); + } + + auto checkMatch = [X, Y](Value *Pattern) { + Value *tX = nullptr, *tY = nullptr; + EXPECT_TRUE(match( + Pattern, m_c_And(m_Value(tX), m_c_And(m_Deferred(tX), m_Value(tY))))); + EXPECT_EQ(tX, X); + EXPECT_EQ(tY, Y); + }; + + checkMatch(IRB.CreateAnd(X, IRB.CreateAnd(X, Y))); + checkMatch(IRB.CreateAnd(X, IRB.CreateAnd(Y, X))); + checkMatch(IRB.CreateAnd(IRB.CreateAnd(X, Y), X)); + checkMatch(IRB.CreateAnd(IRB.CreateAnd(Y, X), X)); +} + TEST_F(PatternMatchTest, FloatingPointOrderedMin) { Type *FltTy = IRB.getFloatTy(); Value *L = ConstantFP::get(FltTy, 1.0); -- 2.7.4