[Matrix] Add minimal lowering pass that only requires TTI.
authorFlorian Hahn <flo@fhahn.com>
Mon, 20 Jul 2020 09:50:15 +0000 (10:50 +0100)
committerFlorian Hahn <flo@fhahn.com>
Mon, 20 Jul 2020 10:16:11 +0000 (11:16 +0100)
This patch adds a new variant of the matrix lowering pass that only does
a minimal lowering and only depends on TTI. The main purpose of this pass
is to have a pass with minimal dependencies to run as part of the backend
pipeline.

At the moment, the only difference to the regular lowering pass is that it
does not support remarks. But in subsequent patches add support for tiling
to the lowering pass which will require more analysis, which we do not want
to run in the backend, as the lowering should happen in the middle-end in
practice and running it in the backend is mostly for convenience when
running llc.

Reviewers: anemet, Gerolf, efriedma, hfinkel

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D76867

llvm/include/llvm/InitializePasses.h
llvm/include/llvm/Transforms/Scalar.h
llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
llvm/lib/Transforms/Scalar/Scalar.cpp
llvm/test/Other/opt-O0-pipeline-enable-matrix.ll
llvm/test/Transforms/LowerMatrixIntrinsics/multiply-minimal.ll [new file with mode: 0644]

index 06e8507..7a85a96 100644 (file)
@@ -267,6 +267,7 @@ void initializeLowerInvokeLegacyPassPass(PassRegistry&);
 void initializeLowerSwitchPass(PassRegistry&);
 void initializeLowerTypeTestsPass(PassRegistry&);
 void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &);
+void initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(PassRegistry &);
 void initializeMIRCanonicalizerPass(PassRegistry &);
 void initializeMIRNamerPass(PassRegistry &);
 void initializeMIRPrintingPassPass(PassRegistry&);
index 7f55835..07d968e 100644 (file)
@@ -370,6 +370,13 @@ Pass *createLowerMatrixIntrinsicsPass();
 
 //===----------------------------------------------------------------------===//
 //
+// LowerMatrixIntrinsicsMinimal - Lower matrix intrinsics to vector operations
+//                               (lightweight, does not require extra analysis)
+//
+Pass *createLowerMatrixIntrinsicsMinimalPass();
+
+//===----------------------------------------------------------------------===//
+//
 // LowerWidenableCondition - Lower widenable condition to i1 true.
 //
 Pass *createLowerWidenableConditionPass();
index a109d69..5a3be75 100644 (file)
@@ -299,7 +299,7 @@ void PassManagerBuilder::populateFunctionPassManager(
   // FIXME: A lightweight version of the pass should run in the backend
   //        pipeline on demand.
   if (EnableMatrix && OptLevel == 0)
-    FPM.add(createLowerMatrixIntrinsicsPass());
+    FPM.add(createLowerMatrixIntrinsicsMinimalPass());
 
   if (OptLevel == 0) return;
 
index 90314b1..1a70071 100644 (file)
@@ -182,10 +182,10 @@ class LowerMatrixIntrinsics {
   Function &Func;
   const DataLayout &DL;
   const TargetTransformInfo &TTI;
-  AliasAnalysis &AA;
-  DominatorTree &DT;
-  LoopInfo &LI;
-  OptimizationRemarkEmitter &ORE;
+  AliasAnalysis *AA;
+  DominatorTree *DT;
+  LoopInfo *LI;
+  OptimizationRemarkEmitter *ORE;
 
   /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
   struct OpInfoTy {
@@ -393,8 +393,8 @@ class LowerMatrixIntrinsics {
 
 public:
   LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
-                        AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
-                        OptimizationRemarkEmitter &ORE)
+                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
+                        OptimizationRemarkEmitter *ORE)
       : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
         LI(LI), ORE(ORE) {}
 
@@ -727,8 +727,10 @@ public:
         Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
     }
 
-    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
-    RemarkGen.emitRemarks();
+    if (ORE) {
+      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
+      RemarkGen.emitRemarks();
+    }
 
     for (Instruction *Inst : reverse(ToRemove))
       Inst->eraseFromParent();
@@ -1085,7 +1087,7 @@ public:
     MemoryLocation StoreLoc = MemoryLocation::get(Store);
     MemoryLocation LoadLoc = MemoryLocation::get(Load);
 
-    AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc);
+    AliasResult LdAliased = AA->alias(LoadLoc, StoreLoc);
 
     // If we can statically determine noalias we're good.
     if (!LdAliased)
@@ -1101,13 +1103,13 @@ public:
     // as we adjust Check0 and Check1's branches.
     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
     for (BasicBlock *Succ : successors(Check0))
-      DTUpdates.push_back({DT.Delete, Check0, Succ});
+      DTUpdates.push_back({DT->Delete, Check0, Succ});
 
-    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
                                     nullptr, "alias_cont");
     BasicBlock *Copy =
-        SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy");
-    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+        SplitBlock(MatMul->getParent(), MatMul, nullptr, LI, nullptr, "copy");
+    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
                                     nullptr, "no_alias");
 
     // Check if the loaded memory location begins before the end of the store
@@ -1152,11 +1154,11 @@ public:
     PHI->addIncoming(NewLd, Copy);
 
     // Adjust DT.
-    DTUpdates.push_back({DT.Insert, Check0, Check1});
-    DTUpdates.push_back({DT.Insert, Check0, Fusion});
-    DTUpdates.push_back({DT.Insert, Check1, Copy});
-    DTUpdates.push_back({DT.Insert, Check1, Fusion});
-    DT.applyUpdates(DTUpdates);
+    DTUpdates.push_back({DT->Insert, Check0, Check1});
+    DTUpdates.push_back({DT->Insert, Check0, Fusion});
+    DTUpdates.push_back({DT->Insert, Check1, Copy});
+    DTUpdates.push_back({DT->Insert, Check1, Fusion});
+    DT->applyUpdates(DTUpdates);
     return PHI;
   }
 
@@ -1272,7 +1274,7 @@ public:
   void LowerMatrixMultiplyFused(CallInst *MatMul,
                                 SmallPtrSetImpl<Instruction *> &FusedInsts) {
     if (!FuseMatrix || !MatMul->hasOneUse() ||
-        MatrixLayout != MatrixLayoutTy::ColumnMajor)
+        MatrixLayout != MatrixLayoutTy::ColumnMajor || !DT)
       return;
 
     auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
@@ -1283,7 +1285,7 @@ public:
       // we create invalid IR.
       // FIXME: See if we can hoist the store address computation.
       auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
-      if (AddrI && (!DT.dominates(AddrI, MatMul)))
+      if (AddrI && (!DT->dominates(AddrI, MatMul)))
         return;
 
       emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
@@ -1868,7 +1870,7 @@ PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
   auto &LI = AM.getResult<LoopAnalysis>(F);
 
-  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
+  LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
   if (LMT.Visit()) {
     PreservedAnalyses PA;
     PA.preserveSet<CFGAnalyses>();
@@ -1894,7 +1896,7 @@ public:
     auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-    LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
+    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
     bool C = LMT.Visit();
     return C;
   }
@@ -1925,3 +1927,45 @@ INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
 Pass *llvm::createLowerMatrixIntrinsicsPass() {
   return new LowerMatrixIntrinsicsLegacyPass();
 }
+
+namespace {
+
+/// A lightweight version of the matrix lowering pass that only requires TTI.
+/// Advanced features that require DT, AA or ORE like tiling are disabled. This
+/// is used to lower matrix intrinsics if the main lowering pass is not run, for
+/// example with -O0.
+class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
+    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
+    bool C = LMT.Visit();
+    return C;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+};
+} // namespace
+
+static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
+char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
+                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
+                      false, false)
+INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
+                    "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
+                    false)
+
+Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
+  return new LowerMatrixIntrinsicsMinimalLegacyPass();
+}
index 42f79d8..a059844 100644 (file)
@@ -83,6 +83,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
   initializeLowerExpectIntrinsicPass(Registry);
   initializeLowerGuardIntrinsicLegacyPassPass(Registry);
   initializeLowerMatrixIntrinsicsLegacyPassPass(Registry);
+  initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(Registry);
   initializeLowerWidenableConditionLegacyPassPass(Registry);
   initializeMemCpyOptLegacyPassPass(Registry);
   initializeMergeICmpsLegacyPassPass(Registry);
index 5e2d272..401cbb9 100644 (file)
@@ -4,19 +4,10 @@
 
 ; CHECK:      Pass Arguments:
 ; CHECK-NEXT: Target Transform Information
-; CHECK-NEXT: Target Library Information
-; CHECK-NEXT: Assumption Cache Tracker
 ; CHECK-NEXT:   FunctionPass Manager
 ; CHECK-NEXT:     Module Verifier
 ; CHECK-NEXT:     Instrument function entry/exit with calls to e.g. mcount() (pre inlining)
-; CHECK-NEXT:     Dominator Tree Construction
-; CHECK-NEXT:     Natural Loop Information
-; CHECK-NEXT:     Lazy Branch Probability Analysis
-; CHECK-NEXT:     Lazy Block Frequency Analysis
-; CHECK-NEXT:     Optimization Remark Emitter
-; CHECK-NEXT:     Basic Alias Analysis (stateless AA impl)
-; CHECK-NEXT:     Function Alias Analysis Results
-; CHECK-NEXT:     Lower the matrix intrinsics
+; CHECK-NEXT:     Lower the matrix intrinsics (minimal)
 
 
 define void @f() {
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-minimal.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-minimal.ll
new file mode 100644 (file)
index 0000000..1271de4
--- /dev/null
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics-minimal -fuse-matrix-tile-size=2 -matrix-allow-contract -force-fuse-matrix -instcombine -verify-dom-info %s -S | FileCheck %s
+
+; Test for the minimal version of the matrix lowering pass, which does not
+; require DT or AA. Make sure no tiling is happening, even though it was
+; requested.
+
+; REQUIRES: aarch64-registered-target
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @multiply(<8 x double> * %A, <8 x double> * %B, <4 x double>* %C) {
+; CHECK-LABEL: @multiply(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast <8 x double>* [[A:%.*]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr <8 x double>, <8 x double>* [[A]], i64 0, i64 2
+; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr <8 x double>, <8 x double>* [[A]], i64 0, i64 4
+; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr <8 x double>, <8 x double>* [[A]], i64 0, i64 6
+; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    [[VEC_CAST9:%.*]] = bitcast <8 x double>* [[B:%.*]] to <4 x double>*
+; CHECK-NEXT:    [[COL_LOAD10:%.*]] = load <4 x double>, <4 x double>* [[VEC_CAST9]], align 8
+; CHECK-NEXT:    [[VEC_GEP11:%.*]] = getelementptr <8 x double>, <8 x double>* [[B]], i64 0, i64 4
+; CHECK-NEXT:    [[VEC_CAST12:%.*]] = bitcast double* [[VEC_GEP11]] to <4 x double>*
+; CHECK-NEXT:    [[COL_LOAD13:%.*]] = load <4 x double>, <4 x double>* [[VEC_CAST12]], align 8
+; CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP0:%.*]] = fmul <2 x double> [[COL_LOAD]], [[SPLAT_SPLAT]]
+; CHECK-NEXT:    [[SPLAT_SPLAT16:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> <i32 1, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD2]], <2 x double> [[SPLAT_SPLAT16]], <2 x double> [[TMP0]])
+; CHECK-NEXT:    [[SPLAT_SPLAT19:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD5]], <2 x double> [[SPLAT_SPLAT19]], <2 x double> [[TMP1]])
+; CHECK-NEXT:    [[SPLAT_SPLAT22:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> <i32 3, i32 3>
+; CHECK-NEXT:    [[TMP3:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD8]], <2 x double> [[SPLAT_SPLAT22]], <2 x double> [[TMP2]])
+; CHECK-NEXT:    [[SPLAT_SPLAT25:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul <2 x double> [[COL_LOAD]], [[SPLAT_SPLAT25]]
+; CHECK-NEXT:    [[SPLAT_SPLAT28:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> <i32 1, i32 1>
+; CHECK-NEXT:    [[TMP5:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD2]], <2 x double> [[SPLAT_SPLAT28]], <2 x double> [[TMP4]])
+; CHECK-NEXT:    [[SPLAT_SPLAT31:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP6:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD5]], <2 x double> [[SPLAT_SPLAT31]], <2 x double> [[TMP5]])
+; CHECK-NEXT:    [[SPLAT_SPLAT34:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> <i32 3, i32 3>
+; CHECK-NEXT:    [[TMP7:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD8]], <2 x double> [[SPLAT_SPLAT34]], <2 x double> [[TMP6]])
+; CHECK-NEXT:    [[VEC_CAST35:%.*]] = bitcast <4 x double>* [[C:%.*]] to <2 x double>*
+; CHECK-NEXT:    store <2 x double> [[TMP3]], <2 x double>* [[VEC_CAST35]], align 8
+; CHECK-NEXT:    [[VEC_GEP36:%.*]] = getelementptr <4 x double>, <4 x double>* [[C]], i64 0, i64 2
+; CHECK-NEXT:    [[VEC_CAST37:%.*]] = bitcast double* [[VEC_GEP36]] to <2 x double>*
+; CHECK-NEXT:    store <2 x double> [[TMP7]], <2 x double>* [[VEC_CAST37]], align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <8 x double>, <8 x double>* %A, align 8
+  %b = load <8 x double>, <8 x double>* %B, align 8
+
+  %c = call <4 x double> @llvm.matrix.multiply(<8 x double> %a, <8 x double> %b, i32 2, i32 4, i32 2)
+
+  store <4 x double> %c, <4 x double>* %C, align 8
+  ret void
+}
+
+declare <4 x double> @llvm.matrix.multiply(<8 x double>, <8 x double>, i32, i32, i32)