/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);
+/// Given MemRef `sizes` that are either static or dynamic, returns the
+/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
+/// once a dynamic dimension is encountered, all canonical strides become
+/// dynamic and need to be encoded with a different symbol.
+/// For canonical stride expressions, the offset is always 0 and the fastest
+/// varying stride is always `1`.
+///
+/// Examples:
+/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
+/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
+/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
+AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+ MLIRContext *context);
+
/// Return true if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);
return success();
}
-/// Given MemRef `sizes` that are either static or dynamic, returns the
-/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
-/// once a dynamic dimension is encountered, all canonical strides become
-/// dynamic and need to be encoded with a different symbol.
-/// For canonical strides expressions, the offset is always 0 and and fastest
-/// varying stride is always `1`.
-///
-/// Examples:
-/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
-/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
-/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
-static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
- MLIRContext *context) {
- AffineExpr expr;
- bool dynamicPoisonBit = false;
- unsigned nSymbols = 0;
- int64_t runningSize = 1;
- unsigned rank = sizes.size();
- for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
- auto size = en.value();
- auto position = rank - 1 - en.index();
- // Degenerate case, no size =-> no stride
- if (size == 0)
- continue;
- auto d = getAffineDimExpr(position, context);
- // Static case: stride = runningSize and runningSize *= size.
- if (!dynamicPoisonBit) {
- auto cst = getAffineConstantExpr(runningSize, context);
- expr = expr ? expr + cst * d : cst * d;
- if (size > 0)
- runningSize *= size;
- else
- // From now on bail into dynamic mode.
- dynamicPoisonBit = true;
- continue;
- }
- // Dynamic case, new symbol for each new stride.
- auto sym = getAffineSymbolExpr(nSymbols++, context);
- expr = expr ? expr + d * sym : d * sym;
- }
- return simplifyAffineExpr(expr, rank, nSymbols);
-}
-
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
return MemRefType::Builder(t).setAffineMaps({});
}
+/// Builds the canonical contiguous-layout stride expression for `sizes`:
+/// offset is implicitly 0, the fastest-varying stride is 1, and each slower
+/// stride is the running product of the faster sizes. Once a dynamic size
+/// (encoded as a negative value) is seen, all slower strides become dynamic
+/// and are each bound to a fresh symbol.
+AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                                MLIRContext *context) {
+  AffineExpr expr;
+  // Set once a dynamic size is encountered; from then on every stride must be
+  // symbolic rather than a constant product.
+  bool dynamicPoisonBit = false;
+  unsigned nSymbols = 0;
+  int64_t runningSize = 1;
+  unsigned rank = sizes.size();
+  // Iterate from the fastest-varying (last) dimension to the slowest so the
+  // running product of sizes yields each dimension's stride.
+  for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
+    auto size = en.value();
+    auto position = rank - 1 - en.index();
+    // Degenerate case, no size -> no stride (the dimension contributes no
+    // term to the expression).
+    if (size == 0)
+      continue;
+    auto d = getAffineDimExpr(position, context);
+    // Static case: stride = runningSize and runningSize *= size.
+    if (!dynamicPoisonBit) {
+      auto cst = getAffineConstantExpr(runningSize, context);
+      expr = expr ? expr + cst * d : cst * d;
+      if (size > 0)
+        runningSize *= size;
+      else
+        // From now on bail into dynamic mode.
+        dynamicPoisonBit = true;
+      continue;
+    }
+    // Dynamic case, new symbol for each new stride.
+    auto sym = getAffineSymbolExpr(nSymbols++, context);
+    expr = expr ? expr + d * sym : d * sym;
+  }
+  // simplifyAffineExpr folds/normalizes the accumulated sum over `rank` dims
+  // and `nSymbols` symbols.
+  return simplifyAffineExpr(expr, rank, nSymbols);
+}
+
/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
int64_t offset;