[MLIR] Expose makeCanonicalStridedLayoutExpr in StandardTypes.h.
authorAlexander Belyaev <pifon@google.com>
Tue, 3 Mar 2020 23:37:50 +0000 (00:37 +0100)
committerAlexander Belyaev <pifon@google.com>
Tue, 3 Mar 2020 23:40:33 +0000 (00:40 +0100)
Differential Revision: https://reviews.llvm.org/D75575

mlir/include/mlir/IR/StandardTypes.h
mlir/lib/IR/StandardTypes.cpp

index cd5ba07..d1c31ac 100644 (file)
@@ -658,6 +658,20 @@ AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
 /// `t` with simplified layout.
 MemRefType canonicalizeStridedLayout(MemRefType t);
 
+/// Given MemRef `sizes` that are either static or dynamic, returns the
+/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
+/// once a dynamic dimension is encountered, all canonical strides become
+/// dynamic and need to be encoded with a different symbol.
+/// For canonical strides expressions, the offset is always 0 and and fastest
+/// varying stride is always `1`.
+///
+/// Examples:
+///   - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
+///   - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
+///   - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
+AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                          MLIRContext *context);
+
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool isStrided(MemRefType t);
 
index 774f80a..488601c 100644 (file)
@@ -456,49 +456,6 @@ UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
   return success();
 }
 
-/// Given MemRef `sizes` that are either static or dynamic, returns the
-/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
-/// once a dynamic dimension is encountered, all canonical strides become
-/// dynamic and need to be encoded with a different symbol.
-/// For canonical strides expressions, the offset is always 0 and and fastest
-/// varying stride is always `1`.
-///
-/// Examples:
-///   - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
-///   - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
-///   - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
-static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
-                                                 MLIRContext *context) {
-  AffineExpr expr;
-  bool dynamicPoisonBit = false;
-  unsigned nSymbols = 0;
-  int64_t runningSize = 1;
-  unsigned rank = sizes.size();
-  for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
-    auto size = en.value();
-    auto position = rank - 1 - en.index();
-    // Degenerate case, no size =-> no stride
-    if (size == 0)
-      continue;
-    auto d = getAffineDimExpr(position, context);
-    // Static case: stride = runningSize and runningSize *= size.
-    if (!dynamicPoisonBit) {
-      auto cst = getAffineConstantExpr(runningSize, context);
-      expr = expr ? expr + cst * d : cst * d;
-      if (size > 0)
-        runningSize *= size;
-      else
-        // From now on bail into dynamic mode.
-        dynamicPoisonBit = true;
-      continue;
-    }
-    // Dynamic case, new symbol for each new stride.
-    auto sym = getAffineSymbolExpr(nSymbols++, context);
-    expr = expr ? expr + d * sym : d * sym;
-  }
-  return simplifyAffineExpr(expr, rank, nSymbols);
-}
-
 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
 // i.e. single term). Accumulate the AffineExpr into the existing one.
 static void extractStridesFromTerm(AffineExpr e,
@@ -766,6 +723,38 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   return MemRefType::Builder(t).setAffineMaps({});
 }
 
+AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                                MLIRContext *context) {
+  AffineExpr expr;
+  bool dynamicPoisonBit = false;
+  unsigned nSymbols = 0;
+  int64_t runningSize = 1;
+  unsigned rank = sizes.size();
+  for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
+    auto size = en.value();
+    auto position = rank - 1 - en.index();
+    // Degenerate case, no size =-> no stride
+    if (size == 0)
+      continue;
+    auto d = getAffineDimExpr(position, context);
+    // Static case: stride = runningSize and runningSize *= size.
+    if (!dynamicPoisonBit) {
+      auto cst = getAffineConstantExpr(runningSize, context);
+      expr = expr ? expr + cst * d : cst * d;
+      if (size > 0)
+        runningSize *= size;
+      else
+        // From now on bail into dynamic mode.
+        dynamicPoisonBit = true;
+      continue;
+    }
+    // Dynamic case, new symbol for each new stride.
+    auto sym = getAffineSymbolExpr(nSymbols++, context);
+    expr = expr ? expr + d * sym : d * sym;
+  }
+  return simplifyAffineExpr(expr, rank, nSymbols);
+}
+
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool mlir::isStrided(MemRefType t) {
   int64_t offset;