[MLIR] Extend vectorization to 2+-D patterns
authorNicolas Vasilache <ntv@google.com>
Thu, 1 Nov 2018 14:14:14 +0000 (07:14 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:46:58 +0000 (13:46 -0700)
This CL adds support for vectorization using more interesting 2-D and 3-D
patterns. Note in particular the fact that we match some pretty complex
imperfectly nested 2-D patterns with a quite minimal change to the
implementation: we just add a bit of recursion to traverse the matched
patterns and actually vectorize the loops.

For instance, vectorizing the following loop by 128:
```
for %i3 = 0 to %0 {
  %7 = affine_apply (d0) -> (d0)(%i3)
  %8 = load %arg0[%c0_0, %7] : memref<?x?xf32>
}
```

Currently generates:
```
#map0 = ()[s0] -> (s0 + 127)
#map1 = (d0) -> (d0)
for %i3 = 0 to #map0()[%0] step 128 {
  %9 = affine_apply #map1(%i3)
  %10 = alloc() : memref<1xvector<128xf32>>
  %11 = "n_d_unaligned_load"(%arg0, %c0_0, %9, %10, %c0) :
    (memref<?x?xf32>, index, index, memref<1xvector<128xf32>>, index) ->
    (memref<?x?xf32>, index, index, memref<1xvector<128xf32>>, index)
   %12 = load %10[%c0] : memref<1xvector<128xf32>>
}
```

The above is subject to evolution.

PiperOrigin-RevId: 219629745

mlir/include/mlir/Analysis/LoopAnalysis.h
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/Transforms/Vectorize.cpp
mlir/test/Transforms/vectorize.mlir

index 6820ee8ad3f1b94dde2c1fd2bf9b15614ecfc6a4..91c8e7478369a869559429046fdabe874b2c70aa 100644 (file)
@@ -52,13 +52,22 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
 // For now we assume no layout map or identity layout map in the MemRef.
 // TODO(ntv): support more than identity layout map.
 bool isAccessInvariant(const MLValue &input, MemRefType memRefType,
-                       llvm::ArrayRef<MLValue *> indices, unsigned dim);
+                       llvm::ArrayRef<const MLValue *> indices, unsigned dim);
 
-/// Checks whether all the LoadOp and StoreOp matched have access indexing
-/// functions that are are either:
+/// Checks whether the loop is structurally vectorizable; i.e.:
+/// 1. the loop has proper dependence semantics (parallel, reduction, etc);
+/// 2. no conditionals are nested under the loop;
+/// 3. all nested load/stores are to scalar MemRefs.
+/// TODO(ntv): implement dependence semantics
+/// TODO(ntv): relax the no-conditionals restriction
+bool isVectorizableLoop(const ForStmt &loop);
+
+/// Checks whether the loop is structurally vectorizable and that all the LoadOp
+/// and StoreOp matched have access indexing functions that are are either:
 ///   1. invariant along the loop induction variable created by 'loop';
 ///   2. varying along the 'fastestVaryingDim' memory dimension.
-bool isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim);
+bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForStmt &loop,
+                                                    unsigned fastestVaryingDim);
 
 /// Checks where SSA dominance would be violated if a for stmt's body statements
 /// are shifted by the specified shifts. This method checks if a 'def' and all
index 1904a6366477629ce469cb19cbfee000bdf0e48c..ce0bc6c3502a715bcbd6b78e1078ec53eaccad1f 100644 (file)
@@ -119,12 +119,12 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
 }
 
 bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
-                             ArrayRef<MLValue *> indices, unsigned dim) {
+                             ArrayRef<const MLValue *> indices, unsigned dim) {
   assert(indices.size() == memRefType.getRank());
   assert(dim < indices.size());
   auto layoutMap = memRefType.getAffineMaps();
   assert(memRefType.getAffineMaps().size() <= 1);
-  // TODO(ntv): remove dependency on Builder once we support non-identity
+  // TODO(ntv): remove dependence on Builder once we support non-identity
   // layout map.
   Builder b(memRefType.getContext());
   assert(layoutMap.empty() ||
@@ -132,7 +132,8 @@ bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
   (void)layoutMap;
 
   SmallVector<OperationStmt *, 4> affineApplyOps;
-  getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
+  getReachableAffineApplyOps({const_cast<MLValue *>(indices[dim])},
+                             affineApplyOps);
 
   if (affineApplyOps.empty()) {
     // Pointer equality test because of MLValue pointer semantics.
@@ -168,7 +169,7 @@ static bool isContiguousAccess(const MLValue &input,
                                LoadOrStoreOpPointer memoryOp,
                                unsigned fastestVaryingDim) {
   using namespace functional;
-  auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
+  auto indices = map([](const SSAValue *val) { return dyn_cast<MLValue>(val); },
                      memoryOp->getIndices());
   auto memRefType = memoryOp->getMemRefType();
   for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
@@ -188,7 +189,11 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
   return memRefType.getElementType().template isa<VectorType>();
 }
 
-bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
+using VectorizableStmtFun =
+    std::function<bool(const ForStmt &, const OperationStmt &)>;
+
+static bool isVectorizableLoopWithCond(const ForStmt &loop,
+                                       VectorizableStmtFun isVectorizableStmt) {
   if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
     return false;
   }
@@ -214,15 +219,32 @@ bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
     if (vector) {
       return false;
     }
-    bool contiguous = load ? isContiguousAccess(loop, load, fastestVaryingDim)
-                           : isContiguousAccess(loop, store, fastestVaryingDim);
-    if (!contiguous) {
+    if (!isVectorizableStmt(loop, *op)) {
       return false;
     }
   }
   return true;
 }
 
+bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
+    const ForStmt &loop, unsigned fastestVaryingDim) {
+  VectorizableStmtFun fun(
+      [fastestVaryingDim](const ForStmt &loop, const OperationStmt &op) {
+        auto load = op.dyn_cast<LoadOp>();
+        auto store = op.dyn_cast<StoreOp>();
+        return load ? isContiguousAccess(loop, load, fastestVaryingDim)
+                    : isContiguousAccess(loop, store, fastestVaryingDim);
+      });
+  return isVectorizableLoopWithCond(loop, fun);
+}
+
+bool mlir::isVectorizableLoop(const ForStmt &loop) {
+  VectorizableStmtFun fun(
+      // TODO: implement me
+      [](const ForStmt &loop, const OperationStmt &op) { return true; });
+  return isVectorizableLoopWithCond(loop, fun);
+}
+
 /// Checks whether SSA dominance would be violated if a for stmt's body
 /// statements are shifted by the specified shifts. This method checks if a
 /// 'def' and all its uses have the same shift factor.
index a6c9681b6f5927bc4db1d4a638524896dbffe3e5..fa97b7025d426fd8328e4ee5098be6efdc4ba823 100644 (file)
@@ -40,7 +40,6 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include <functional>
 
 using namespace llvm;
 using namespace mlir;
@@ -54,34 +53,47 @@ using namespace mlir;
 ///      operation. The full semantics of this unaligned load/store is still
 ///      TBD.
 ///
-/// Which loop transformation to apply to coarsen for early vectorization is
-/// still subject to exploratory tradeoffs. In particular, say we want to
-/// vectorize by a factor 128, we want to transform:
+/// Loop transformation:
+//  ====================
+/// The choice of loop transformation to apply for coarsening vectorized loops
+/// is still subject to exploratory tradeoffs. In particular, say we want to
+/// vectorize by a factor 128, we want to transform the following input:
 ///     for %i = %M to %N {
-///       load/store(f(i)) ...
+///       %a = load(f(i))
 ///
-///   traditionally, one would vectorize late (after scheduling, tiling,
+///   Traditionally, one would vectorize late (after scheduling, tiling,
 ///   memory promotion etc) say after stripmining (and potentially unrolling in
 ///   the case of LLVM's SLP vectorizer):
 ///     for %i = floor(%M, 128) to ceil(%N, 128) {
 ///       for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) {
-///         load/store(f(ii)) ...
+///         load/store(f(ii))
 ///
-///   we seek to vectorize early and freeze vector types before scheduling:
+///   We seek to vectorize early and freeze vector types before scheduling, so
+///   we want to generate a pattern that resembles:
 ///     for %i = ? to ? step ? {
-///       unaligned_load/unaligned_store(g(i)) ...
+///       unaligned_load/unaligned_store(g(i))
 ///
 ///   i. simply dividing the lower / upper bounds by 128 creates issues
 ///   with representing expressions such as ii + 1 because now we only
 ///   have access to original values that have been divided. Additional
 ///   information is needed to specify accesses at below 128 granularity;
 ///   ii. another alternative is to coarsen the loop step but this may have
-///   consequences on dependency analysis and fusability of loops: fusable
+///   consequences on dependence analysis and fusability of loops: fusable
 ///   loops probably need to have the same step (because we don't want to
 ///   stripmine/unroll to enable fusion).
 /// As a consequence, we choose to represent the coarsening using the loop
 /// step for now and reevaluate in the future. Note that we can renormalize
 /// loop steps later if/when we have evidence that they are problematic.
+///
+/// For the simple strawman example above, vectorizing for a 1-D vector
+/// abstraction of size 128 returns code similar to:
+///   %c0 = constant 0 : index
+///   for %i = %M to %N + 127 step 128 {
+///     %a = alloc() : memref<1xvector<128xf32>>
+///     %r = "n_d_unaligned_load"(%tensor, %i, %a, %c0)
+///     %16 = load %a[%c0] : memref<1xvector<128xf32>>
+///
+/// Note this is still work in progress and not yet functional.
 
 #define DEBUG_TYPE "early-vect"
 
@@ -92,9 +104,9 @@ static cl::list<int> clVirtualVectorSize(
 
 static cl::list<int> clFastestVaryingPattern(
     "test-fastest-varying",
-    cl::desc("Specify a 1-D pattern of fastest varying memory dimensions"
-             " to match. See defaultPatterns in Vectorize.cpp for a description"
-             " and examples. This is used for testing purposes"),
+    cl::desc("Specify a 1-D, 2-D or 3-D pattern of fastest varying memory "
+             "dimensions to match. See defaultPatterns in Vectorize.cpp for a "
+             "description and examples. This is used for testing purposes"),
     cl::ZeroOrMore);
 
 /// Forward declaration.
@@ -104,21 +116,56 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension);
 // Build a bunch of predetermined patterns that will be traversed in order.
 // Due to the recursive nature of MLFunctionMatchers, this captures
 // arbitrarily nested pairs of loops at any position in the tree.
-// TODO(ntv): support 2-D and 3-D loop patterns with a common reduction loop
-// that can be matched to GEMMs.
+/// Note that this currently only matches 2 nested loops and will be extended.
+// TODO(ntv): support 3-D loop patterns with a common reduction loop that can
+// be matched to GEMMs.
 static std::vector<MLFunctionMatcher> defaultPatterns() {
   using matcher::For;
   return std::vector<MLFunctionMatcher>{
-      // for i { A[ ??f(not i) , f(i)];}
+      // 3-D patterns
+      For(isVectorizableLoopPtrFactory(2),
+          For(isVectorizableLoopPtrFactory(1),
+              For(isVectorizableLoopPtrFactory(0)))),
+      // for i { for j { A[??f(not i, not j), f(i, not j), f(not i, j)];}}
+      // test independently with:
+      //   --test-fastest-varying=1 --test-fastest-varying=0
+      For(isVectorizableLoopPtrFactory(1),
+          For(isVectorizableLoopPtrFactory(0))),
+      // for i { for j { A[??f(not i, not j), f(i, not j), ?, f(not i, j)];}}
+      // test independently with:
+      //   --test-fastest-varying=2 --test-fastest-varying=0
+      For(isVectorizableLoopPtrFactory(2),
+          For(isVectorizableLoopPtrFactory(0))),
+      // for i { for j { A[??f(not i, not j), f(i, not j), ?, ?, f(not i, j)];}}
+      // test independently with:
+      //   --test-fastest-varying=3 --test-fastest-varying=0
+      For(isVectorizableLoopPtrFactory(3),
+          For(isVectorizableLoopPtrFactory(0))),
+      // for i { for j { A[??f(not i, not j), f(not i, j), f(i, not j)];}}
+      // test independently with:
+      //   --test-fastest-varying=0 --test-fastest-varying=1
+      For(isVectorizableLoopPtrFactory(0),
+          For(isVectorizableLoopPtrFactory(1))),
+      // for i { for j { A[??f(not i, not j), f(not i, j), ?, f(i, not j)];}}
+      // test independently with:
+      //   --test-fastest-varying=0 --test-fastest-varying=2
+      For(isVectorizableLoopPtrFactory(0),
+          For(isVectorizableLoopPtrFactory(2))),
+      // for i { for j { A[??f(not i, not j), f(not i, j), ?, ?, f(i, not j)];}}
+      // test independently with:
+      //   --test-fastest-varying=0 --test-fastest-varying=3
+      For(isVectorizableLoopPtrFactory(0),
+          For(isVectorizableLoopPtrFactory(3))),
+      // for i { A[??f(not i) , f(i)];}
       // test independently with:  --test-fastest-varying=0
       For(isVectorizableLoopPtrFactory(0)),
-      // for i { A[ ??f(not i) , f(i), ?];}
+      // for i { A[??f(not i) , f(i), ?];}
       // test independently with:  --test-fastest-varying=1
       For(isVectorizableLoopPtrFactory(1)),
-      // for i { A[ ??f(not i) , f(i), ?, ?];}
+      // for i { A[??f(not i) , f(i), ?, ?];}
       // test independently with:  --test-fastest-varying=2
       For(isVectorizableLoopPtrFactory(2)),
-      // for i { A[ ??f(not i) , f(i), ?, ?, ?];}
+      // for i { A[??f(not i) , f(i), ?, ?, ?];}
       // test independently with:  --test-fastest-varying=3
       For(isVectorizableLoopPtrFactory(3))};
 }
@@ -131,9 +178,17 @@ static std::vector<MLFunctionMatcher> makePatterns() {
   switch (clFastestVaryingPattern.size()) {
   case 1:
     return {For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[0]))};
+  case 2:
+    return {For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[0]),
+                For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[1])))};
+  case 3:
+    return {For(
+        isVectorizableLoopPtrFactory(clFastestVaryingPattern[0]),
+        For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[1]),
+            For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[2]))))};
   default:
-    assert(false && "Only up to 1-D fastest varying pattern supported atm");
-  };
+    assert(false && "Only up to 3-D fastest varying pattern supported atm");
+  }
   return std::vector<MLFunctionMatcher>();
 }
 
@@ -159,13 +214,11 @@ struct Strategy {
 
 /// Implements a simple strawman strategy for vectorization.
 /// Given a matched pattern `matches` of depth `patternDepth`, this strategy
-/// greedily assigns the fastest varying dimension **of the vector** to the
+/// greedily assigns the fastest varying dimension ** of the vector ** to the
 /// innermost loop in the pattern.
-/// When coupled with a pattern that looks for the fastest varying dimension
-/// ** in load/store MemRefs**, this creates a generic vectorization strategy
-/// that works for any loop in a hierarchy (outermost, innermost or
-/// intermediate) as well as any fastest varying dimension in a load/store
-/// MemRef.
+/// When coupled with a pattern that looks for the fastest varying dimension in
+/// load/store MemRefs, this creates a generic vectorization strategy that works
+/// for any loop in a hierarchy (outermost, innermost or intermediate).
 ///
 /// TODO(ntv): In the future we should additionally increase the power of the
 /// profitability analysis along 3 directions:
@@ -179,6 +232,11 @@ static bool analyzeProfitability(MLFunctionMatches matches,
                                  Strategy *strategy) {
   for (auto m : matches) {
     auto *loop = cast<ForStmt>(m.first);
+    LLVM_DEBUG(dbgs() << "[early-vect][profitability] patternDepth: "
+                      << patternDepth << " depthInPattern: " << depthInPattern
+                      << " loop ");
+    LLVM_DEBUG(loop->print(dbgs()));
+    LLVM_DEBUG(dbgs() << "\n");
     bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
                                      strategy);
     if (fail) {
@@ -188,6 +246,10 @@ static bool analyzeProfitability(MLFunctionMatches matches,
     if (patternDepth - depthInPattern <= clVirtualVectorSize.size()) {
       strategy->loopToVectorDim[loop] =
           clVirtualVectorSize.size() - (patternDepth - depthInPattern);
+      LLVM_DEBUG(dbgs() << "[early-vect][profitability] vectorize @ "
+                        << strategy->loopToVectorDim[loop] << " loop ");
+      LLVM_DEBUG(loop->print(dbgs()));
+      LLVM_DEBUG(dbgs() << "\n");
     } else {
       // Don't vectorize
       strategy->loopToVectorDim[loop] = -1;
@@ -323,13 +385,15 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc,
 /// MemRef<1 x vector_type> + a custom unaligned load/store pseudoop.
 /// The vector load/store accessing this MemRef always accesses element 0, so we
 /// just memoize a single 0 SSAValue, once upon function entry to avoid clutter.
-static SSAValue *getZeroIndex(MLFuncBuilder *b) {
-  static SSAValue *z = nullptr;
-  if (!z) {
-    auto zero = b->createChecked<ConstantIndexOp>(b->getUnknownLoc(), 0);
-    z = zero->getOperation()->getResult(0);
+/// We create one such SSAValue per function.
+static SSAValue *insertZeroIndex(MLFunction *f) {
+  static thread_local DenseMap<MLFunction *, SSAValue *> zeros;
+  if (zeros.find(f) == zeros.end()) {
+    MLFuncBuilder b(f);
+    auto zero = b.create<ConstantIndexOp>(b.getUnknownLoc(), 0);
+    zeros.insert(std::make_pair(f, zero));
   }
-  return z;
+  return zeros.lookup(f);
 }
 
 /// Unwraps a pointer type to another type (possibly the same).
@@ -356,17 +420,18 @@ static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
   auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
   MLFuncBuilder b(opStmt);
   // Create an AllocOp to apply the new shape.
-  auto allocOp = b.createChecked<AllocOp>(opStmt->getLoc(), vectorMemRefType,
-                                          ArrayRef<SSAValue *>{});
+  auto allocOp = b.create<AllocOp>(opStmt->getLoc(), vectorMemRefType,
+                                   ArrayRef<SSAValue *>{});
   auto *allocMemRef = memoryOp->getMemRef();
   using namespace functional;
   if (opStmt->template isa<LoadOp>()) {
     createUnalignedLoad(&b, opStmt->getLoc(), allocMemRef,
                         map(unwrapPtr<SSAValue>(), memoryOp->getIndices()),
-                        allocOp->getResult(), {getZeroIndex(&b)});
+                        allocOp->getResult(),
+                        {insertZeroIndex(iv->getFunction())});
   } else {
     createUnalignedStore(&b, opStmt->getLoc(), allocOp->getResult(),
-                         {getZeroIndex(&b)}, allocMemRef,
+                         {insertZeroIndex(iv->getFunction())}, allocMemRef,
                          map(unwrapPtr<SSAValue>(), memoryOp->getIndices()));
   }
 
@@ -377,6 +442,9 @@ static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
 namespace {
 
 struct VectorizationState {
+  // `vectorizedByThisPattern` keeps track of statements that have already been
+  // vectorized by this pattern. This allows distinguishing between
+  DenseSet<OperationStmt *> vectorizedByThisPattern;
   DenseSet<ForStmt *> vectorized;
   const Strategy *strategy;
 };
@@ -385,18 +453,18 @@ struct VectorizationState {
 /// Terminal template function for creating a LoadOp.
 static OpPointer<LoadOp> createLoad(MLFuncBuilder *b, Location *loc,
                                     MLValue *memRef) {
-  using namespace functional;
-  return b->createChecked<LoadOp>(loc, memRef,
-                                  ArrayRef<SSAValue *>{getZeroIndex(b)});
+  return b->create<LoadOp>(
+      loc, memRef,
+      ArrayRef<SSAValue *>{insertZeroIndex(memRef->getFunction())});
 }
 
 /// Terminal template function for creating a StoreOp.
 static OpPointer<StoreOp> createStore(MLFuncBuilder *b, Location *loc,
                                       MLValue *memRef,
                                       OpPointer<StoreOp> store) {
-  using namespace functional;
-  return b->createChecked<StoreOp>(loc, store->getValueToStore(), memRef,
-                                   ArrayRef<SSAValue *>{getZeroIndex(b)});
+  return b->create<StoreOp>(
+      loc, store->getValueToStore(), memRef,
+      ArrayRef<SSAValue *>{insertZeroIndex(memRef->getFunction())});
 }
 
 /// Vectorizes the `memoryOp` of type LoadOp or StoreOp along loop `iv` by
@@ -419,6 +487,7 @@ static bool vectorize(MLValue *iv, LoadOrStoreOpPointer memoryOp,
     auto res = createStore(&b, opStmt->getLoc(), materializedMemRef, store);
     resultOperation = res->getOperation();
   }
+  state->vectorizedByThisPattern.insert(cast<OperationStmt>(resultOperation));
   return false;
 }
 
@@ -436,7 +505,14 @@ static bool vectorizeForStmt(ForStmt *loop, AffineMap upperBound,
                       upperBound);
   loop->setStep(step);
 
-  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
+  FilterFunctionType notVectorizedThisRound = [state](const Statement &stmt) {
+    if (!matcher::isLoadOrStore(stmt)) {
+      return false;
+    }
+    return state->vectorizedByThisPattern.count(cast<OperationStmt>(&stmt)) ==
+           0;
+  };
+  auto loadAndStores = matcher::Op(notVectorizedThisRound);
   auto matches = loadAndStores.match(loop);
   for (auto ls : matches) {
     auto *opStmt = cast<OperationStmt>(ls.first);
@@ -447,6 +523,7 @@ static bool vectorizeForStmt(ForStmt *loop, AffineMap upperBound,
     LLVM_DEBUG(dbgs() << "\n");
     bool vectorizationFails = load ? vectorize(loop, load, vectorSize, state)
                                    : vectorize(loop, store, vectorSize, state);
+    LLVM_DEBUG(dbgs() << "fail: " << vectorizationFails << "\n");
     if (vectorizationFails) {
       // Early exit and trigger RAII cleanups at the root.
       return true;
@@ -467,13 +544,28 @@ static FilterFunctionType
 isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
   return [fastestVaryingMemRefDimension](const Statement &forStmt) {
     const auto &loop = cast<ForStmt>(forStmt);
-    return isVectorizableLoop(loop, fastestVaryingMemRefDimension);
+    return isVectorizableLoopAlongFastestVaryingMemRefDim(
+        loop, fastestVaryingMemRefDimension);
   };
 }
 
-/// Apply vectorization of `loop` according to `state`.
-static bool doVectorize(ForStmt *loop, VectorizationState *state) {
-  // This loop may have been omitted from vectorization for various reasons
+/// Forward-declaration.
+static bool vectorizeNonRoot(MLFunctionMatches matches,
+                             VectorizationState *state);
+
+/// Apply vectorization of `loop` according to `state`. This is only triggered
+/// if all vectorizations in `childrenMatches` have already succeeded
+/// recursively in DFS post-order.
+static bool doVectorize(ForStmt *loop, MLFunctionMatches childrenMatches,
+                        VectorizationState *state) {
+  // 1. DFS postorder recursion, if any of my children fails, I fail too.
+  auto fail = vectorizeNonRoot(childrenMatches, state);
+  if (fail) {
+    // Early exit and trigger RAII cleanups at the root.
+    return true;
+  }
+
+  // 2. This loop may have been omitted from vectorization for various reasons
   // (e.g. due to the performance model or pattern depth > vector size).
   assert(state->strategy->loopToVectorDim.count(loop));
   assert(state->strategy->loopToVectorDim.find(loop) !=
@@ -484,7 +576,7 @@ static bool doVectorize(ForStmt *loop, VectorizationState *state) {
     return false;
   }
 
-  // Apply transformation.
+  // 3. Actual post-order transformation.
   assert(vectorDim < clVirtualVectorSize.size() && "vector dim overflow");
   //   a. get actual vector size
   auto vectorSize = clVirtualVectorSize[vectorDim];
@@ -498,14 +590,29 @@ static bool doVectorize(ForStmt *loop, VectorizationState *state) {
   std::function<AffineExpr(AffineExpr)> coarsenUb =
       [vectorSize](AffineExpr expr) { return expr + vectorSize - 1; };
   auto newUbs = functional::map(coarsenUb, ubMap.getResults());
+  //   d. recurse
   return vectorizeForStmt(
       loop,
       AffineMap::get(ubMap.getNumDims(), ubMap.getNumSymbols(), newUbs, {}),
       clVirtualVectorSize, loop->getStep() * vectorSize, state);
 }
 
+/// Non-root pattern iterates over the matches at this level, calls doVectorize
+/// and exits early if anything below fails.
+static bool vectorizeNonRoot(MLFunctionMatches matches,
+                             VectorizationState *state) {
+  for (auto m : matches) {
+    auto fail = doVectorize(cast<ForStmt>(m.first), m.second, state);
+    if (fail) {
+      // Early exit and trigger RAII cleanups at the root.
+      return true;
+    }
+  }
+  return false;
+}
+
 /// Sets up error handling for this root loop.
-/// Vectorization is a procedure where anything below can fail.
+/// Vectorization is a recursive procedure where anything below can fail.
 /// The root match thus needs to maintain a clone for handling failure.
 /// Each root may succeed independently but will otherwise clean after itself if
 /// anything below it fails.
@@ -520,17 +627,19 @@ static bool vectorizeRoot(MLFunctionMatches matches,
     // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
     // TODO(ntv): implement a non-greedy profitability analysis that keeps only
     // non-intersecting patterns.
-    if (!isVectorizableLoop(*loop, 0)) {
-      // TODO(ntv): this is too restrictive and will break a bunch of patterns
-      // that do not require vectorization along the 0^th fastest memory
-      // dimension.
+    if (!isVectorizableLoop(*loop)) {
       continue;
     }
-
     DenseMap<const MLValue *, MLValue *> nomap;
     MLFuncBuilder builder(loop->getFunction());
     ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
-    doVectorize(loop, state) ? loop->erase() : clonedLoop->erase();
+    auto fail = doVectorize(loop, m.second, state);
+    if (!fail) {
+      LLVM_DEBUG(dbgs() << "[early-vect] success vectorizing loop: ");
+      LLVM_DEBUG(loop->print(dbgs()));
+      LLVM_DEBUG(dbgs() << "\n");
+    }
+    fail ? loop->erase() : clonedLoop->erase();
   }
   return false;
 }
@@ -540,16 +649,14 @@ static bool vectorizeRoot(MLFunctionMatches matches,
 PassResult Vectorize::runOnMLFunction(MLFunction *f) {
   /// Build a zero at the entry of the function to avoid clutter in every single
   /// vectorized loop.
-  {
-    MLFuncBuilder b(f);
-    getZeroIndex(&b);
-  }
+  insertZeroIndex(f);
+
   for (auto pat : makePatterns()) {
     LLVM_DEBUG(dbgs() << "\n[early-vect] Input function is now:\n");
     LLVM_DEBUG(f->print(dbgs()));
+    LLVM_DEBUG(dbgs() << "\n[early-vect] match:\n");
     auto matches = pat.match(f);
     Strategy strategy;
-    assert(pat.getDepth() == 1 && "only 1-D patterns and vector supported atm");
     auto fail = analyzeProfitability(matches, 0, pat.getDepth(), &strategy);
     assert(!fail);
     VectorizationState state;
index 7122b445598c15dca85028e5eac476134744e11e..29f878b5e49ee93089936cfd09897263f27178bf 100644 (file)
@@ -1,4 +1,9 @@
-// RUN: mlir-opt %s -vectorize -virtual-vector-size 128 --test-fastest-varying=0 | FileCheck %s
+// RUN: mlir-opt %s -vectorize -virtual-vector-size 128 --test-fastest-varying=0 | FileCheck %s -check-prefix=VEC1D
+// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=1 --test-fastest-varying=0 | FileCheck %s -check-prefix=VEC2D
+// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=1 | FileCheck %s -check-prefix=VEC2D_T
+// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=0 | FileCheck %s -check-prefix=VEC2D_O
+// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=2 | FileCheck %s -check-prefix=VEC2D_OT
+// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 64 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=1 --test-fastest-varying=0 | FileCheck %s -check-prefix=VEC3D
 
 #map0 = (d0) -> (d0)
 #map1 = (d0, d1) -> (d0, d1)
 #mapadd3 = (d0) -> (d0 + 3)
 #set0 = (i) : (i >= 0)
 // Maps introduced to vectorize fastest varying memory index.
-// CHECK: [[MAPSHIFT:#map[0-9]*]] = ()[s0] -> (s0 + 127)
+// VEC1D: [[MAPSHIFT:#map[0-9]*]] = ()[s0] -> (s0 + 127)
 mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
-// CHECK: [[C0:%[a-z0-9_]+]] = constant 0 : index
-// CHECK: [[ARG_M:%[0-9]+]] = dim %arg0, 0 : memref<?x?xf32>
-// CHECK: [[ARG_N:%[0-9]+]] = dim %arg0, 1 : memref<?x?xf32>
-// CHECK: [[ARG_P:%[0-9]+]] = dim %arg1, 2 : memref<?x?x?xf32>
+// VEC1D: [[C0:%[a-z0-9_]+]] = constant 0 : index
+// VEC1D: [[ARG_M:%[0-9]+]] = dim %arg0, 0 : memref<?x?xf32>
+// VEC1D: [[ARG_N:%[0-9]+]] = dim %arg0, 1 : memref<?x?xf32>
+// VEC1D: [[ARG_P:%[0-9]+]] = dim %arg1, 2 : memref<?x?x?xf32>
    %M = dim %A, 0 : memref<?x?xf32>
    %N = dim %A, 1 : memref<?x?xf32>
    %P = dim %B, 2 : memref<?x?x?xf32>
-// CHECK: [[C1:%[a-z0-9_]+]] = constant 0 : index
+// VEC1D: [[C1:%[a-z0-9_]+]] = constant 0 : index
    %cst0 = constant 0 : index
 // CHECK:for [[IV0:%[a-zA-Z0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
 // CHECK:   [[ALLOC0:%[a-zA-Z0-9]+]] = alloc() : memref<1xvector<128xf32>>
@@ -35,58 +40,58 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    for %i0 = 0 to %M { // vectorized due to scalar -> vector 
      %a0 = load %A[%cst0, %cst0] : memref<?x?xf32>
    }
-// CHECK:for {{.*}} [[ARG_M]] {
+// VEC1D:for {{.*}} [[ARG_M]] {
    for %i1 = 0 to %M { // not vectorized 
      %a1 = load %A[%i1, %i1] : memref<?x?xf32>
    }
-// CHECK:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+// VEC1D:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
    for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 
      %r2 = affine_apply (d0) -> (d0) (%i2)
      %a2 = load %A[%r2#0, %cst0] : memref<?x?xf32>
    }
-// CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
-// CHECK:   [[APP3:%[a-zA-Z0-9]+]] = affine_apply {{.*}}[[IV3]]
-// CHECK:   [[ALLOC3:%[a-zA-Z0-9]+]] = alloc() : memref<1xvector<128xf32>>
-// CHECK:   [[UNALIGNED3: %[0-9]*]] = "n_d_unaligned_load"(%arg0, [[C1]], [[APP3]], [[ALLOC3]], [[C0]]) : {{.*}}
-// CHECK:   %{{.*}} = load [[ALLOC3]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
+// VEC1D:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
+// VEC1D:   [[APP3:%[a-zA-Z0-9]+]] = affine_apply {{.*}}[[IV3]]
+// VEC1D:   [[ALLOC3:%[a-zA-Z0-9]+]] = alloc() : memref<1xvector<128xf32>>
+// VEC1D:   [[UNALIGNED3: %[0-9]*]] = "n_d_unaligned_load"(%arg0, [[C1]], [[APP3]], [[ALLOC3]], [[C0]]) : {{.*}}
+// VEC1D:   %{{.*}} = load [[ALLOC3]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
    for %i3 = 0 to %M { // vectorized
      %r3 = affine_apply (d0) -> (d0) (%i3)
      %a3 = load %A[%cst0, %r3#0] : memref<?x?xf32>
    }
-// CHECK:for [[IV4:%[i0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
-// CHECK:   for [[IV5:%[i0-9]*]] = 0 to %{{[0-9]*}} {
-// CHECK:   [[APP5:%[0-9]+]] = affine_apply {{.*}}([[IV4]], [[IV5]])
-// CHECK:   [[ALLOC5:%[0-9]+]] = alloc() : memref<1xvector<128xf32>>
-// CHECK:   [[UNALIGNED5:%[0-9]*]] = "n_d_unaligned_load"(%arg0, [[APP5]]#0, [[APP5]]#1, [[ALLOC5]], [[C0]]) : {{.*}}
-// CHECK:   %{{.*}} = load [[ALLOC5]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
+// VEC1D:for [[IV4:%[i0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
+// VEC1D:   for [[IV5:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D:   [[APP5:%[0-9]+]] = affine_apply {{.*}}([[IV4]], [[IV5]])
+// VEC1D:   [[ALLOC5:%[0-9]+]] = alloc() : memref<1xvector<128xf32>>
+// VEC1D:   [[UNALIGNED5:%[0-9]*]] = "n_d_unaligned_load"(%arg0, [[APP5]]#0, [[APP5]]#1, [[ALLOC5]], [[C0]]) : {{.*}}
+// VEC1D:   %{{.*}} = load [[ALLOC5]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
    for %i4 = 0 to %M { // vectorized 
      for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
        %r5 = affine_apply #map1_t (%i4, %i5)
        %a5 = load %A[%r5#0, %r5#1] : memref<?x?xf32>
      }
    }
-// CHECK: for [[IV6:%[i0-9]*]] = 0 to %{{[0-9]*}} {
-// CHECK:   for [[IV7:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D: for [[IV6:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D:   for [[IV7:%[i0-9]*]] = 0 to %{{[0-9]*}} {
    for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1
      for %i7 = 0 to %N { // not vectorized, can never vectorize
        %r7 = affine_apply #map2 (%i6, %i7)
        %a7 = load %A[%r7#0, %r7#1] : memref<?x?xf32>
      }
    }
-// CHECK:for [[IV8:%[i0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
-// CHECK:   for [[IV9:%[i0-9]*]] = 0 to %{{[0-9]*}} {
-// CHECK:   [[APP9:%[0-9]+]] = affine_apply {{.*}}([[IV8]], [[IV9]])
-// CHECK:   [[ALLOC9:%[0-9]+]] = alloc() : memref<1xvector<128xf32>>
-// CHECK:   [[UNALIGNED9:%[0-9]*]] = "n_d_unaligned_load"(%arg0, [[APP9]]#0, [[APP9]]#1, [[ALLOC9]], [[C0]]) : {{.*}}
-// CHECK:   %{{.*}} = load [[ALLOC9]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
+// VEC1D:for [[IV8:%[i0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
+// VEC1D:   for [[IV9:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D:   [[APP9:%[0-9]+]] = affine_apply {{.*}}([[IV8]], [[IV9]])
+// VEC1D:   [[ALLOC9:%[0-9]+]] = alloc() : memref<1xvector<128xf32>>
+// VEC1D:   [[UNALIGNED9:%[0-9]*]] = "n_d_unaligned_load"(%arg0, [[APP9]]#0, [[APP9]]#1, [[ALLOC9]], [[C0]]) : {{.*}}
+// VEC1D:   %{{.*}} = load [[ALLOC9]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
    for %i8 = 0 to %M { // vectorized
      for %i9 = 0 to %N {
        %r9 = affine_apply #map3 (%i8, %i9)
        %a9 = load %A[%r9#0, %r9#1] : memref<?x?xf32>
      }
    }
-// CHECK: for [[IV10:%[i0-9]*]] = 0 to %{{[0-9]*}} {
-// CHECK:   for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D: for [[IV10:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D:   for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} {
    for %i10 = 0 to %M { // not vectorized, need per load transposes 
      for %i11 = 0 to %N { // not vectorized, need per load transposes 
        %r11 = affine_apply #map1 (%i10, %i11)
@@ -95,9 +100,9 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
        store %a11, %A[%r12#0, %r12#1] : memref<?x?xf32>
      }
    }
-// CHECK: for [[IV12:%[i0-9]*]] = 0 to %{{[0-9]*}} {
-// CHECK:   for [[IV13:%[i0-9]*]] = 0 to %{{[0-9]*}} {
-// CHECK:     for [[IV14:%[i0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_P]]{{.}} step 128
+// VEC1D: for [[IV12:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D:   for [[IV13:%[i0-9]*]] = 0 to %{{[0-9]*}} {
+// VEC1D:     for [[IV14:%[i0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_P]]{{.}} step 128
    for %i12 = 0 to %M { // not vectorized, can never vectorize
      for %i13 = 0 to %N { // not vectorized, can never vectorize
        for %i14 = 0 to %P { // vectorized
@@ -106,26 +111,204 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
        }
      }
    }
-// CHECK:  for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+// VEC1D:  for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
    for %i15 = 0 to %M { // not vectorized due to condition below
      if #set0(%i15) {
        %a15 = load %A[%cst0, %cst0] : memref<?x?xf32>
      }
    }
-// CHECK:  for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+// VEC1D:  for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
    for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load
      %a16 = alloc(%M) : memref<?xvector<2xf32>>
      %l16 = load %a16[%i16] : memref<?xvector<2xf32>>
    }
-// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
-// CHECK:   for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
-// CHECK:     [[ALLOC18:%[a-zA-Z0-9]+]] = alloc() : memref<1xvector<128xf32>>
-// CHECK:     [[UNALIGNED18: %[0-9]*]] = "n_d_unaligned_load"(%arg0, [[C1]], [[C1]], [[ALLOC18]], [[C0]]) : {{.*}}
-// CHECK:     %{{.*}} = load [[ALLOC18]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
+// VEC1D: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+// VEC1D:   for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[MAPSHIFT]](){{.}}[[ARG_M]]{{.}} step 128
+// VEC1D:     [[ALLOC18:%[a-zA-Z0-9]+]] = alloc() : memref<1xvector<128xf32>>
+// VEC1D:     [[UNALIGNED18: %[0-9]*]] = "n_d_unaligned_load"(%arg0, [[C1]], [[C1]], [[ALLOC18]], [[C0]]) : {{.*}}
+// VEC1D:     %{{.*}} = load [[ALLOC18]]{{.}}[[C0]]{{.}} : memref<1xvector<128xf32>>
    for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17
-     for %i18 = 0 to %M { // vectorized due to scalar -> vector 
+     for %i18 = 0 to %M { // vectorized due to scalar -> vector
        %a18 = load %A[%cst0, %cst0] : memref<?x?xf32>
      }
    }
    return
 }
+
+// VEC2D: [[MAPSHIFT0:#map[0-9]*]] = ()[s0] -> (s0 + 31)
+// VEC2D: [[MAPSHIFT1:#map[0-9]*]] = ()[s0] -> (s0 + 255)
+// VEC2D_T: [[MAPSHIFT0:#map[0-9]*]] = ()[s0] -> (s0 + 31)
+// VEC2D_T: [[MAPSHIFT1:#map[0-9]*]] = ()[s0] -> (s0 + 255)
+// VEC2D_O: [[MAPSHIFT0:#map[0-9]*]] = ()[s0] -> (s0 + 31)
+// VEC2D_O: [[MAPSHIFT1:#map[0-9]*]] = ()[s0] -> (s0 + 255)
+// VEC2D_OT: [[MAPSHIFT0:#map[0-9]*]] = ()[s0] -> (s0 + 31)
+// VEC2D_OT: [[MAPSHIFT1:#map[0-9]*]] = ()[s0] -> (s0 + 255)
+mlfunc @vec2d(%A : memref<?x?x?xf32>) {
+   %M = dim %A, 0 : memref<?x?x?xf32>
+   %N = dim %A, 1 : memref<?x?x?xf32>
+   %P = dim %A, 2 : memref<?x?x?xf32>
+   // VEC2D: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D:   for {{.*}} = 0 to [[MAPSHIFT0]](){{.*}} step 32
+   // VEC2D:     for {{.*}} = 0 to [[MAPSHIFT1]](){{.*}} step 256
+   // For the case: --test-fastest-varying=1 --test-fastest-varying=0:
+   // for %i0 = 0 to %0 {
+   //   for %i1 = 0 to #map6()[%1] step 32 {
+   //     for %i2 = 0 to #map7()[%2] step 256 {
+   //       %3 = alloc() : memref<1xvector<32x256xf32>>
+   //       %4 = "n_d_unaligned_load"(%arg0, %i0, %i1, %i2, %3, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   //       %5 = load %3[%c0] : memref<1xvector<32x256xf32>>
+   //
+   // VEC2D_T: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_T:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_T:     for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // For the case: --test-fastest-varying=0 --test-fastest-varying=1 no
+   // vectorization happens because of loop nesting order (i.e. only one of
+   // VEC2D and VEC2D_T may ever vectorize).
+   //
+   // VEC2D_O: for {{.*}} = 0 to [[MAPSHIFT0]](){{.*}} step 32
+   // VEC2D_O:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_O:     for {{.*}} = 0 to [[MAPSHIFT1]](){{.*}} step 256
+   // For the case: --test-fastest-varying=2 --test-fastest-varying=0:
+   // for %i0 = 0 to #map6()[%0] step 32 {
+   //   for %i1 = 0 to %1 {
+   //     for %i2 = 0 to #map7()[%2] step 256 {
+   //       %3 = alloc() : memref<1xvector<32x256xf32>>
+   //       %4 = "n_d_unaligned_load"(%arg0, %i0, %i1, %i2, %3, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   //       %5 = load %3[%c0] : memref<1xvector<32x256xf32>>
+   //
+   // VEC2D_OT: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_OT:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_OT:     for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // For the case: --test-fastest-varying=0 --test-fastest-varying=2 no
+   // vectorization happens because of loop nesting order(i.e. only one of
+   // VEC2D_O and VEC2D_OT may ever vectorize).
+   for %i0 = 0 to %M {
+     for %i1 = 0 to %N {
+       for %i2 = 0 to %P {
+         %a2 = load %A[%i0, %i1, %i2] : memref<?x?x?xf32>
+       }
+     }
+   }
+   // VEC2D: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D:     for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // For the case: --test-fastest-varying=1 --test-fastest-varying=0 no
+   // vectorization happens because of loop nesting order (i.e. only one of
+   // VEC2D and VEC2D_T may ever vectorize).
+   //
+   // VEC2D_T: for {{.*}} = 0 to [[MAPSHIFT0]](){{.*}} step 32
+   // VEC2D_T:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_T:     for {{.*}} = 0 to [[MAPSHIFT1]](){{.*}} step 256
+   // For the case: --test-fastest-varying=0 --test-fastest-varying=1:
+   // for %i3 = 0 to #map1()[%0] step 32 {
+   //   for %i4 = 0 to %1 {
+   //     for %i5 = 0 to #map2()[%2] step 256 {
+   //       %4 = alloc() : memref<1xvector<32x256xf32>>
+   //       %5 = "n_d_unaligned_load"(%arg0, %i4, %i5, %i3, %4, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   //       %6 = load %4[%c0] : memref<1xvector<32x256xf32>>
+   //
+   // VEC2D_O: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_O:   for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // VEC2D_O:     for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // For the case: --test-fastest-varying=2 --test-fastest-varying=0 no
+   // vectorization happens because of loop nesting order(i.e. only one of
+   // VEC2D_O and VEC2D_OT may ever vectorize).
+   //
+   // VEC2D_OT: for {{.*}} = 0 to [[MAPSHIFT0]](){{.*}} step 32
+   // VEC2D_OT:   for {{.*}} = 0 to [[MAPSHIFT1]](){{.*}} step 256
+   // VEC2D_OT:     for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
+   // For the case: --test-fastest-varying=0 --test-fastest-varying=2:
+   // for %i3 = 0 to #map6()[%0] step 32 {
+   //   for %i4 = 0 to #map7()[%1] step 256 {
+   //     for %i5 = 0 to %2 {
+   //       %4 = alloc() : memref<1xvector<32x256xf32>>
+   //       %5 = "n_d_unaligned_load"(%arg0, %i4, %i5, %i3, %4, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   //       %6 = load %4[%c0] : memref<1xvector<32x256xf32>>
+   for %i3 = 0 to %M {
+     for %i4 = 0 to %N {
+       for %i5 = 0 to %P {
+         %a5 = load %A[%i4, %i5, %i3] : memref<?x?x?xf32>
+       }
+     }
+   }
+   return
+}
+
+mlfunc @vec2d_imperfectly_nested(%A : memref<?x?x?xf32>) {
+   %0 = dim %A, 0 : memref<?x?x?xf32>
+   %1 = dim %A, 1 : memref<?x?x?xf32>
+   %2 = dim %A, 2 : memref<?x?x?xf32>
+   // VEC2D_T: for %i0 = 0 to #map{{.}}()[%0] step 32 {
+   // VEC2D_T:   for %i1 = 0 to #map{{.}}()[%1] step 256 {
+   // VEC2D_T:     for %i2 = 0 to %2 {
+   // VEC2D_T:       %3 = alloc() : memref<1xvector<32x256xf32>>
+   // VEC2D_T:       %4 = "n_d_unaligned_load"(%arg0, %i2, %i1, %i0, %3, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   // VEC2D_T:       %5 = load %3[%c0] : memref<1xvector<32x256xf32>>
+   // VEC2D_T:   for %i3 = 0 to %1 {
+   // VEC2D_T:     for %i4 = 0 to #map{{.}}()[%2] step 256 {
+   // VEC2D_T:       %6 = alloc() : memref<1xvector<32x256xf32>>
+   // VEC2D_T:       %7 = "n_d_unaligned_load"(%arg0, %i3, %i4, %i0, %6, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   // VEC2D_T:       %8 = load %6[%c0] : memref<1xvector<32x256xf32>>
+   // VEC2D_T:     for %i5 = 0 to #map{{.}}()[%2] step 256 {
+   // VEC2D_T:       %9 = alloc() : memref<1xvector<32x256xf32>>
+   // VEC2D_T:       %10 = "n_d_unaligned_load"(%arg0, %i3, %i5, %i0, %9, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   // VEC2D_T:       %11 = load %9[%c0] : memref<1xvector<32x256xf32>>
+   //
+   // VEC2D_OT: for %i0 = 0 to #map{{.}}()[%0] step 32 {
+   // VEC2D_OT:   for %i1 = 0 to %1 {
+   // VEC2D_OT:     for %i2 = 0 to #map{{.}}()[%2] step 256 {
+   // VEC2D_OT:       %3 = alloc() : memref<1xvector<32x256xf32>>
+   // VEC2D_OT:       %4 = "n_d_unaligned_load"(%arg0, %i2, %i1, %i0, %3, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   // VEC2D_OT:       %5 = load %3[%c0] : memref<1xvector<32x256xf32>>
+   // VEC2D_OT:   for %i3 = 0 to #map{{.}}()[%1] step 256 {
+   // VEC2D_OT:     for %i4 = 0 to %2 {
+   // VEC2D_OT:       %6 = alloc() : memref<1xvector<32x256xf32>>
+   // VEC2D_OT:       %7 = "n_d_unaligned_load"(%arg0, %i3, %i4, %i0, %6, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   // VEC2D_OT:       %8 = load %6[%c0] : memref<1xvector<32x256xf32>>
+   // VEC2D_OT:     for %i5 = 0 to %2 {
+   // VEC2D_OT:       %9 = alloc() : memref<1xvector<32x256xf32>>
+   // VEC2D_OT:       %10 = "n_d_unaligned_load"(%arg0, %i3, %i5, %i0, %9, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x256xf32>>, index)
+   // VEC2D_OT:       %11 = load %9[%c0] : memref<1xvector<32x256xf32>>
+   for %i0 = 0 to %0 {
+     for %i1 = 0 to %1 {
+       for %i2 = 0 to %2 {
+         %a2 = load %A[%i2, %i1, %i0] : memref<?x?x?xf32>
+       }
+     }
+     for %i3 = 0 to %1 {
+       for %i4 = 0 to %2 {
+         %a4 = load %A[%i3, %i4, %i0] : memref<?x?x?xf32>
+       }
+       for %i5 = 0 to %2 {
+         %a5 = load %A[%i3, %i5, %i0] : memref<?x?x?xf32>
+       }
+     }
+   }
+   return
+}
+
+mlfunc @vec3d(%A : memref<?x?x?xf32>) {
+   %0 = dim %A, 0 : memref<?x?x?xf32>
+   %1 = dim %A, 1 : memref<?x?x?xf32>
+   %2 = dim %A, 2 : memref<?x?x?xf32>
+   // VEC3D: for %i0 = 0 to %0 {
+   // VEC3D:   for %i1 = 0 to %0 {
+   // VEC3D:     for %i2 = 0 to #map{{.}}()[%0] step 32 {
+   // VEC3D:       for %i3 = 0 to #map{{.}}()[%1] step 64 {
+   // VEC3D:         for %i4 = 0 to #map{{.}}()[%2] step 256 {
+   // VEC3D:           %3 = alloc() : memref<1xvector<32x64x256xf32>>
+   // VEC3D:           %4 = "n_d_unaligned_load"(%arg0, %i2, %i3, %i4, %3, %c0) : (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x64x256xf32>>, index) -> (memref<?x?x?xf32>, index, index, index, memref<1xvector<32x64x256xf32>>, index)
+   // VEC3D:           %5 = load %3[%c0] : memref<1xvector<32x64x256xf32>>
+   for %t0 = 0 to %0 {
+     for %t1 = 0 to %0 {
+       for %i0 = 0 to %0 {
+         for %i1 = 0 to %1 {
+           for %i2 = 0 to %2 {
+             %a2 = load %A[%i0, %i1, %i2] : memref<?x?x?xf32>
+           }
+         }
+       }
+     }
+   }
+   return
+}
\ No newline at end of file