[mlir][Linalg] Reimplement and extend getStridesAndOffset
authorNicolas Vasilache <ntv@google.com>
Thu, 2 Jan 2020 20:25:21 +0000 (15:25 -0500)
committerNicolas Vasilache <ntv@google.com>
Mon, 6 Jan 2020 14:41:38 +0000 (09:41 -0500)
Summary: This diff reimplements getStridesAndOffset in a significantly simpler way by operating on the AffineExpr and calling into simplifyAffineExpr instead of rolling its own saturating arithmetic.

As a consequence it becomes quite simple to extend the behavior of getStridesAndOffset to encompass more cases by manipulating the AffineExpr more directly.
The divisions are still filtered out and continue to yield fully dynamic strides.
Simplifying the divisions is left for a later time if compelling use cases arise.

Relevant tests are added.

Reviewers: ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72098

mlir/lib/IR/StandardTypes.cpp
mlir/test/AffineOps/memref-stride-calculation.mlir

index 441b59e..55d7baa 100644 (file)
@@ -456,126 +456,73 @@ static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
     auto sym = getAffineSymbolExpr(nSymbols++, context);
     expr = expr ? expr + d * sym : d * sym;
   }
-  return expr;
-}
-
-// Factored out common logic to update `strides` and `seen` for `dim` with value
-// `val`. This handles both saturated and unsaturated cases.
-static void accumulateStrides(MutableArrayRef<int64_t> strides,
-                              MutableArrayRef<bool> seen, unsigned pos,
-                              int64_t val) {
-  if (!seen[pos]) {
-    // Newly seen case, sets value
-    strides[pos] = val;
-    seen[pos] = true;
-    return;
-  }
-  if (strides[pos] != MemRefType::getDynamicStrideOrOffset())
-    // Already seen case accumulates unless they are already saturated.
-    strides[pos] += val;
-}
-
-// This sums multiple offsets as they are seen. In the particular case of
-// accumulating a dynamic offset with either a static of dynamic one, this
-// saturates to MemRefType::getDynamicStrideOrOffset().
-static void accumulateOffset(int64_t &offset, bool &seenOffset, int64_t val) {
-  if (!seenOffset) {
-    // Newly seen case, sets value
-    offset = val;
-    seenOffset = true;
-    return;
-  }
-  if (offset != MemRefType::getDynamicStrideOrOffset())
-    // Already seen case accumulates unless they are already saturated.
-    offset += val;
+  return simplifyAffineExpr(expr, rank, nSymbols);
 }
 
-/// Takes a single AffineExpr `e` and populates the `strides` and `seen` arrays
-/// with the strides values for each dim position and whether a value exists at
-/// that position, respectively.
+// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
+// i.e. single term). Accumulate the AffineExpr into the existing one.
+static void extractStridesFromTerm(AffineExpr e,
+                                   AffineExpr multiplicativeFactor,
+                                   MutableArrayRef<AffineExpr> strides,
+                                   AffineExpr &offset) {
+  if (auto dim = e.dyn_cast<AffineDimExpr>())
+    strides[dim.getPosition()] =
+        strides[dim.getPosition()] + multiplicativeFactor;
+  else
+    offset = offset + e * multiplicativeFactor;
+}
+
+/// Takes a single AffineExpr `e` and populates the `strides` array with the
+/// strides expressions for each dim position.
 /// The convention is that the strides for dimensions d0, .. dn appear in
 /// order to make indexing intuitive into the result.
-static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
-                           int64_t &offset, MutableArrayRef<bool> seen,
-                           bool &seenOffset, bool &failed) {
+static LogicalResult extractStrides(AffineExpr e,
+                                    AffineExpr multiplicativeFactor,
+                                    MutableArrayRef<AffineExpr> strides,
+                                    AffineExpr &offset) {
   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
-  if (!bin)
-    return;
+  if (!bin) {
+    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
+    return success();
+  }
 
   if (bin.getKind() == AffineExprKind::CeilDiv ||
       bin.getKind() == AffineExprKind::FloorDiv ||
-      bin.getKind() == AffineExprKind::Mod) {
-    failed = true;
-    return;
-  }
+      bin.getKind() == AffineExprKind::Mod)
+    return failure();
+
   if (bin.getKind() == AffineExprKind::Mul) {
-    // LHS may be more complex than just a single dim (e.g. multiple syms and
-    // dims). Bail out for now and revisit when we have evidence this is needed.
     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
-    if (!dim) {
-      failed = true;
-      return;
-    }
-    auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
-    if (!cst) {
-      strides[dim.getPosition()] = MemRefType::getDynamicStrideOrOffset();
-      seen[dim.getPosition()] = true;
-    } else {
-      accumulateStrides(strides, seen, dim.getPosition(), cst.getValue());
+    if (dim) {
+      strides[dim.getPosition()] =
+          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
+      return success();
     }
-    return;
+    // LHS and RHS may both contain complex expressions of dims. Try one path
+    // and if it fails try the other. This is guaranteed to succeed because
+    // only one path may have a `dim`, otherwise this is not an AffineExpr in
+    // the first place.
+    if (bin.getLHS().isSymbolicOrConstant())
+      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
+                            strides, offset);
+    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
+                          strides, offset);
   }
+
   if (bin.getKind() == AffineExprKind::Add) {
-    for (auto e : {bin.getLHS(), bin.getRHS()}) {
-      if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
-        // Independent constants cumulate.
-        accumulateOffset(offset, seenOffset, cst.getValue());
-      } else if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
-        // Independent symbols saturate.
-        offset = MemRefType::getDynamicStrideOrOffset();
-        seenOffset = true;
-      } else if (auto dim = e.dyn_cast<AffineDimExpr>()) {
-        // Independent symbols cumulate 1.
-        accumulateStrides(strides, seen, dim.getPosition(), 1);
-      }
-      // Sum of binary ops dispatch to the respective exprs.
-    }
-    return;
+    auto res1 =
+        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
+    auto res2 =
+        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
+    return success(succeeded(res1) && succeeded(res2));
   }
-  llvm_unreachable("unexpected binary operation");
-}
 
-// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
-// i.e. single term).
-static void extractStridesFromTerm(AffineExpr e,
-                                   MutableArrayRef<int64_t> strides,
-                                   int64_t &offset, MutableArrayRef<bool> seen,
-                                   bool &seenOffset) {
-  if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
-    assert(!seenOffset && "unexpected `seen` bit with single term");
-    offset = cst.getValue();
-    seenOffset = true;
-    return;
-  }
-  if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
-    assert(!seenOffset && "unexpected `seen` bit with single term");
-    offset = MemRefType::getDynamicStrideOrOffset();
-    seenOffset = true;
-    return;
-  }
-  if (auto dim = e.dyn_cast<AffineDimExpr>()) {
-    assert(!seen[dim.getPosition()] &&
-           "unexpected `seen` bit with single term");
-    strides[dim.getPosition()] = 1;
-    seen[dim.getPosition()] = true;
-    return;
-  }
   llvm_unreachable("unexpected binary operation");
 }
 
-LogicalResult mlir::getStridesAndOffset(MemRefType t,
-                                        SmallVectorImpl<int64_t> &strides,
-                                        int64_t &offset) {
+static LogicalResult getStridesAndOffset(MemRefType t,
+                                         SmallVectorImpl<AffineExpr> &strides,
+                                         AffineExpr &offset) {
   auto affineMaps = t.getAffineMaps();
   // For now strides are only computed on a single affine map with a single
   // result (i.e. the closed subset of linearization maps that are compatible
@@ -583,39 +530,58 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
   // TODO(ntv): support more forms on a per-need basis.
   if (affineMaps.size() > 1)
     return failure();
-  AffineExpr stridedExpr;
-  if (affineMaps.empty() || affineMaps[0].isIdentity()) {
-    if (t.getRank() == 0) {
-      // Handle 0-D corner case.
-      offset = 0;
-      return success();
-    }
-    stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-  } else if (affineMaps[0].getNumResults() == 1) {
-    stridedExpr = affineMaps[0].getResult(0);
-  }
-  if (!stridedExpr)
+  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
     return failure();
 
-  bool failed = false;
-  strides = SmallVector<int64_t, 4>(t.getRank(), 0);
-  bool seenOffset = false;
-  SmallVector<bool, 4> seen(t.getRank(), false);
-  if (stridedExpr.isa<AffineBinaryOpExpr>()) {
-    stridedExpr.walk([&](AffineExpr e) {
-      if (!failed)
-        extractStrides(e, strides, offset, seen, seenOffset, failed);
-    });
-  } else {
-    extractStridesFromTerm(stridedExpr, strides, offset, seen, seenOffset);
+  auto zero = getAffineConstantExpr(0, t.getContext());
+  auto one = getAffineConstantExpr(1, t.getContext());
+  offset = zero;
+  strides.assign(t.getRank(), zero);
+
+  AffineMap m;
+  if (!affineMaps.empty()) {
+    m = affineMaps.front();
+    assert(!m.isIdentity() && "unexpected identity map");
   }
 
-  // Constant offset may not be present in `stridedExpr` which means it is
-  // implicitly 0.
-  if (!seenOffset)
-    offset = 0;
+  // Canonical case for empty map.
+  if (!m) {
+    // 0-D corner case, offset is already 0.
+    if (t.getRank() == 0)
+      return success();
+    auto stridedExpr =
+        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
+    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
+      return success();
+    assert(false && "unexpected failure: extract strides in canonical layout");
+  }
+
+  // Non-canonical case requires more work.
+  auto stridedExpr =
+      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
+    offset = AffineExpr();
+    strides.clear();
+    return failure();
+  }
 
-  if (failed || !llvm::all_of(seen, [](bool b) { return b; })) {
+  // Simplify results to allow folding to constants and simple checks.
+  unsigned numDims = m.getNumDims();
+  unsigned numSymbols = m.getNumSymbols();
+  offset = simplifyAffineExpr(offset, numDims, numSymbols);
+  for (auto &stride : strides)
+    stride = simplifyAffineExpr(stride, numDims, numSymbols);
+
+  /// In practice, a strided memref must be internally non-aliasing. Test
+  /// against 0 as a proxy.
+  /// TODO(ntv) static cases can have more advanced checks.
+  /// TODO(ntv) dynamic cases would require a way to compare symbolic
+  /// expressions and would probably need an affine set context propagated
+  /// everywhere.
+  if (llvm::any_of(strides, [](AffineExpr e) {
+        return e == getAffineConstantExpr(0, e.getContext());
+      })) {
+    offset = AffineExpr();
     strides.clear();
     return failure();
   }
@@ -623,6 +589,26 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
   return success();
 }
 
+LogicalResult mlir::getStridesAndOffset(MemRefType t,
+                                        SmallVectorImpl<int64_t> &strides,
+                                        int64_t &offset) {
+  AffineExpr offsetExpr;
+  SmallVector<AffineExpr, 4> strideExprs;
+  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
+    return failure();
+  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
+    offset = cst.getValue();
+  else
+    offset = ShapedType::kDynamicStrideOrOffset;
+  for (auto e : strideExprs) {
+    if (auto c = e.dyn_cast<AffineConstantExpr>())
+      strides.push_back(c.getValue());
+    else
+      strides.push_back(ShapedType::kDynamicStrideOrOffset);
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 /// ComplexType
 //===----------------------------------------------------------------------===//
index 6efd21d..aacd0c7 100644 (file)
@@ -67,5 +67,15 @@ func @f(%0: index) {
 // CHECK: MemRefType memref<3x4x5xf32, (d0, d1, d2) -> (d0 ceildiv 4 + d1 + d2)> cannot be converted to strided form
   %103 = alloc() : memref<3x4x5xf32, (i, j, k)->(i mod 4 + j + k)>
 // CHECK: MemRefType memref<3x4x5xf32, (d0, d1, d2) -> (d0 mod 4 + d1 + d2)> cannot be converted to strided form
+
+  %200 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * i + N * i + N * j + K * k - (M + N - 20)* i)>
+  // CHECK: MemRefType offset: 0 strides: 20, ?, ?
+  %201 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * i + N * i + N * K * j + K * K * k - (M + N - 20) * (i + 1))>
+  // CHECK: MemRefType offset: ? strides: 20, ?, ?
+  %202 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * (i + 1) + j + k - M)>
+  // CHECK: MemRefType offset: 0 strides: ?, 1, 1
+  %203 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M + M * (i + N * (j + K * k)))>
+  // CHECK: MemRefType offset: ? strides: ?, ?, ?
+
   return
 }