[mlir][linalg] Use affine apply in im2col gather index calculations
authorQuinn Dawkins <quinn@nod-labs.com>
Fri, 24 Mar 2023 15:44:51 +0000 (11:44 -0400)
committerQuinn Dawkins <quinn@nod-labs.com>
Fri, 24 Mar 2023 15:49:15 +0000 (11:49 -0400)
Differential Revision: https://reviews.llvm.org/D146816

mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

index 58a23e2..491c533 100644 (file)
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -41,37 +42,17 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
   return builder.create<arith::MulFOp>(loc, x, y);
 }
 
-// Unrolls the given composite `index` into a set of subindices with maximum
-// iteration ranges specified by `factors` according to the following
-// assumptions:
-//   1. The iteration range for `index` is [0, f1 * f2 * ... * fn] i.e. the
-//   product of the given list of factors
-//   2. The iterators corresponding to the entries in `factors` are ordered from
-//   slowest to fastest varying
-// Each subindex is then computed as:
-//    subindex[i] = floor( (index % (fi * ... * fn)) / (fi-1 * ... * fn) )
-static SmallVector<Value, 3> unrollIndex(OpBuilder &b, Location loc,
-                                         Value index,
-                                         ArrayRef<int64_t> factors) {
+// Delinearizes the given composite `index` by the basis specified in `factors`.
+static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
+                                      ArrayRef<int64_t> factors) {
   assert(factors.size() >= 1 && "empty factor list");
-  SmallVector<Value, 3> indices(factors.size());
-  int64_t runningProd = 1;
-  for (int i = factors.size() - 1, end = 0; i >= end; i--) {
-    Value unrolledIndex = index;
-    if (i > 0) {
-      Value modBase = b.create<arith::ConstantOp>(
-          loc, b.getIndexAttr(runningProd * factors[i]));
-      unrolledIndex = b.create<arith::RemUIOp>(loc, unrolledIndex, modBase);
-    }
-    if (runningProd > 1) {
-      Value divDenom =
-          b.create<arith::ConstantOp>(loc, b.getIndexAttr(runningProd));
-      unrolledIndex = b.create<arith::DivUIOp>(loc, unrolledIndex, divDenom);
-    }
-    runningProd *= factors[i];
-    indices[i] = unrolledIndex;
-  }
-  return indices;
+  SmallVector<Value> basis;
+  for (int64_t f : factors)
+    basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
+  FailureOr<SmallVector<Value>> multiIndex =
+      delinearizeIndex(b, loc, index, basis);
+  assert(!failed(multiIndex) && "Failed to linearize img2col index");
+  return *multiIndex;
 }
 
 // Given indices corresponding to iterators in the output (oIndex) and filter
@@ -79,9 +60,10 @@ static SmallVector<Value, 3> unrollIndex(OpBuilder &b, Location loc,
 // input as `oIndex * stride + fIndex`.
 static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
                                Value fIndex, int64_t stride) {
-  Value strideVal = b.create<arith::ConstantOp>(loc, b.getIndexAttr(stride));
-  Value convIndex = b.create<arith::MulIOp>(loc, oIndex, strideVal);
-  return b.create<arith::AddIOp>(loc, convIndex, fIndex);
+  AffineExpr oExpr, fExpr;
+  bindSymbols(b.getContext(), oExpr, fExpr);
+  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
+  return makeComposedAffineApply(b, loc, convMap, ValueRange{oIndex, fIndex});
 }
 
 FailureOr<std::pair<Operation *, Operation *>>
@@ -159,12 +141,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
         Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
 
         // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value, 3> mIndices = unrollIndex(
+        SmallVector<Value> mIndices = unrollIndex(
             nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
         auto ohIndex = mIndices[0];
         auto owIndex = mIndices[1];
 
-        SmallVector<Value, 3> kIndices = unrollIndex(
+        SmallVector<Value> kIndices = unrollIndex(
             nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
         auto fhIndex = kIndices[0];
         auto fwIndex = kIndices[1];
@@ -443,13 +425,13 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
         Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
 
         // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value, 3> kIndices = unrollIndex(
+        SmallVector<Value> kIndices = unrollIndex(
             nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
         auto icIndex = kIndices[0];
         auto fhIndex = kIndices[1];
         auto fwIndex = kIndices[2];
 
-        SmallVector<Value, 3> nIndices = unrollIndex(
+        SmallVector<Value> nIndices = unrollIndex(
             nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
         auto ohIndex = nIndices[0];
         auto owIndex = nIndices[1];
index ffcba10..38c6349 100644 (file)
@@ -37,29 +37,12 @@ transform.sequence failures(propagate) {
 // CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
 // CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
 
-// Unrolled output shape indices.
-// CHECK: %[[C14:.+]] = arith.constant 14 : index
-// CHECK: %[[OWINDEX:.+]] = arith.remui %[[MINDEX]], %[[C14]] : index
-// CHECK: %[[C14_1:.+]] = arith.constant 14 : index
-// CHECK: %[[OHINDEX:.+]] = arith.divui %[[MINDEX]], %[[C14_1]] : index
+// Compute input channel/convolved indices.
+// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
+// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
+// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
 
-// Unrolled filter shape indices.
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[ICINDEX:.+]] = arith.remui %[[KINDEX]], %[[C4]] : index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[FWREM:.+]] = arith.remui %[[KINDEX]], %[[C12]] : index
-// CHECK: %[[C4_2:.+]] = arith.constant 4 : index
-// CHECK: %[[FWINDEX:.+]] = arith.divui %[[FWREM]], %[[C4_2]] : index
-// CHECK: %[[C12_3:.+]] = arith.constant 12 : index
-// CHECK: %[[FHINDEX:.+]] = arith.divui %[[KINDEX]], %[[C12_3]] : index
-
-// Compute input indices.
-// CHECK: %[[SH:.+]] = arith.constant 1 : index
-// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
-// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
-// CHECK: %[[SW:.+]] = arith.constant 1 : index
-// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
-// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+// Extract from the input tensor.
 // CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
 // CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
 // CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
@@ -234,6 +217,13 @@ transform.sequence failures(propagate) {
 // -----
 
 //  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+//  Im2col maps
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 9)>
+//  CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1) -> (d0 floordiv 14 + (d1 mod 9) floordiv 3)>
+//  CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 - (d0 floordiv 14) * 14 - (d1 floordiv 3) * 3)>
+
+
 //  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 //  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 //  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -252,29 +242,12 @@ transform.sequence failures(propagate) {
 //      CHECK:       %[[KINDEX:.+]] = linalg.index 1 : index
 //      CHECK:       %[[NINDEX:.+]] = linalg.index 2 : index
 
-//      Unrolled filter shape indices.
-//      CHECK:       %[[C3:.+]] = arith.constant 3 : index
-//      CHECK:       %[[FWINDEX:.+]] = arith.remui %[[KINDEX]], %[[C3]] : index
-//      CHECK:       %[[C9:.+]] = arith.constant 9 : index
-//      CHECK:       %[[FHREM:.+]] = arith.remui %[[KINDEX]], %[[C9]] : index
-//      CHECK:       %[[C3_1:.+]] = arith.constant 3 : index
-//      CHECK:       %[[FHINDEX:.+]] = arith.divui %[[FHREM]], %[[C3_1]] : index
-//      CHECK:       %[[C9_2:.+]] = arith.constant 9 : index
-//      CHECK:       %[[ICINDEX:.+]] = arith.divui %[[KINDEX]], %[[C9_2]] : index
-
-//      Unrolled output shape indices.
-//      CHECK:       %[[C14:.+]] = arith.constant 14 : index
-//      CHECK:       %[[OWINDEX:.+]] = arith.remui %[[NINDEX]], %[[C14]] : index
-//      CHECK:       %[[C14_3:.+]] = arith.constant 14 : index
-//      CHECK:       %[[OHINDEX:.+]] = arith.divui %[[NINDEX]], %[[C14_3]] : index
+//      Compute input channel/convolved indices.
+//      CHECK:       %[[ICINDEX:.+]] = affine.apply #[[MAP1]](%[[KINDEX]])
+//      CHECK:       %[[CONVH:.+]] = affine.apply #[[MAP7]](%[[NINDEX]], %[[KINDEX]])
+//      CHECK:       %[[CONVW:.+]] = affine.apply #[[MAP8]](%[[NINDEX]], %[[KINDEX]])
 
-//      Compute input indices.
-//      CHECK:       %[[SH:.+]] = arith.constant 1 : index
-//      CHECK:       %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
-//      CHECK:       %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
-//      CHECK:       %[[SW:.+]] = arith.constant 1 : index
-//      CHECK:       %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
-//      CHECK:       %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+//      Extract from the input tensor.
 //      CHECK:       %[[EXTRACTED_INPUT:.+]] = tensor.extract
 //      CHECK-SAME:  %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
 //      CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32