Fix improperly indexed DimOp in LowerVectorTransfers.cpp
authorNicolas Vasilache <ntv@google.com>
Wed, 16 Jan 2019 22:06:20 +0000 (14:06 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 22:24:13 +0000 (15:24 -0700)
This CL fixes a misunderstanding in how to build DimOp which triggered
execution issues in the CPU path.

The problem is that, given a `memref<?x4x?x8x?xf32>`, the expressions to
construct the dynamic dimensions should be:
`dim %arg, 0 : memref<?x4x?x8x?xf32>`
`dim %arg, 2 : memref<?x4x?x8x?xf32>`
and
`dim %arg, 4 : memref<?x4x?x8x?xf32>`

Before this CL, we wold construct:
`dim %arg, 0 : memref<?x4x?x8x?xf32>`
`dim %arg, 1 : memref<?x4x?x8x?xf32>`
`dim %arg, 2 : memref<?x4x?x8x?xf32>`

and expect the other dimensions to be constants.
This assumption seems consistent at first glance with the syntax of alloc:

```
    %tensor = alloc(%M, %N, %O) : memref<?x4x?x8x?xf32>
```

But this was actuallyincorrect.

This CL also makes the relevant functions available to EDSCs and removes
duplication of the incorrect function.

PiperOrigin-RevId: 229622766

mlir/include/mlir/EDSC/MLIREmitter.h
mlir/lib/EDSC/LowerEDSCTestPass.cpp
mlir/lib/EDSC/MLIREmitter.cpp
mlir/lib/Transforms/LowerVectorTransfers.cpp
mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir
mlir/test/mlir-tblgen/reference-impl.td

index 1b6c6fc23b00e59533373d66556a99ea890a15d5..24b5a5f16a375fa15177498cdd754b38fd9e01e8 100644 (file)
@@ -114,6 +114,15 @@ struct MLIREmitter {
   /// Prerequisite: it must exist.
   Value *getValue(Expr expr) { return ssaBindings.lookup(expr); }
 
+  /// Returns a list of `Bindable` that are bound to the dimensions of the
+  /// memRef. The proper DimOp and ConstantOp are constructed at the current
+  /// insertion point in `builder`. They can be later hoisted and simplified in
+  /// a separate pass.
+  ///
+  /// Prerequisite:
+  /// `memRef` is a Value of type MemRefType.
+  SmallVector<edsc::Bindable, 8> makeBoundSizes(Value *memRef);
+
 private:
   FuncBuilder *builder;
   Location location;
index 76aa9ffd9951e1edda2404a0e966494f8cc7f365..e891be68fd34c42e3c1260cdfad9790a5736cff6 100644 (file)
@@ -42,36 +42,6 @@ struct LowerEDSCTestPass : public FunctionPass {
 
 char LowerEDSCTestPass::passID = 0;
 
-// TODO: These should be in a common library.
-static bool isDynamicSize(int size) { return size < 0; }
-
-static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
-                                              Value *memRef) {
-  auto memRefType = memRef->getType().cast<MemRefType>();
-  SmallVector<Value *, 8> res;
-  res.reserve(memRefType.getShape().size());
-  unsigned countSymbolicShapes = 0;
-  for (int size : memRefType.getShape()) {
-    if (isDynamicSize(size)) {
-      res.push_back(b->create<DimOp>(loc, memRef, countSymbolicShapes++));
-    } else {
-      res.push_back(b->create<ConstantIndexOp>(loc, size));
-    }
-  }
-  return res;
-}
-
-SmallVector<edsc::Bindable, 8> makeBoundSizes(edsc::MLIREmitter *emitter,
-                                              Value *memRef) {
-  MemRefType memRefType = memRef->getType().cast<MemRefType>();
-  auto memRefSizes = edsc::makeBindables(memRefType.getShape().size());
-  auto memrefSizeValues =
-      getMemRefSizes(emitter->getBuilder(), emitter->getLocation(), memRef);
-  assert(memrefSizeValues.size() == memRefSizes.size());
-  emitter->bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
-  return memRefSizes;
-}
-
 #include "mlir/EDSC/reference-impl.inc"
 
 PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
index 7bc24ee8f7d80dd5df9c2f1efcf434e430d0bd9c..1020173e1a639733de4cdc90303d0b07e01f68c1 100644 (file)
@@ -334,7 +334,6 @@ void MLIREmitter::emitStmt(const Stmt &stmt) {
   if (stmt.getRHS().getKind() != ExprKind::Block) {
     auto *val = emit(stmt.getRHS());
     if (!val) {
-      llvm::errs() << "\n" << stmt.getRHS();
       assert((stmt.getRHS().getKind() == ExprKind::Dealloc ||
               stmt.getRHS().getKind() == ExprKind::Store) &&
              "dealloc or store expected as the only 0-result ops");
@@ -356,5 +355,62 @@ void MLIREmitter::emitStmts(ArrayRef<Stmt> stmts) {
   }
 }
 
+static bool isDynamicSize(int size) { return size < 0; }
+
+/// This function emits the proper Value* at the place of insertion of b,
+/// where each value is the proper ConstantOp or DimOp. Returns a vector with
+/// these Value*. Note this function does not concern itself with hoisting of
+/// constants and will produce redundant IR. Subsequent MLIR simplification
+/// passes like LICM and CSE are expected to clean this up.
+///
+/// More specifically, a MemRefType has a shape vector in which:
+///   - constant ranks are embedded explicitly with their value;
+///   - symbolic ranks are represented implicitly by -1 and need to be recovered
+///     with a DimOp operation.
+///
+/// Example:
+/// When called on:
+///
+/// ```mlir
+///    memref<?x3x4x?x5xf32>
+/// ```
+///
+/// This emits MLIR similar to:
+///
+/// ```mlir
+///    %d0 = dim %0, 0 : memref<?x3x4x?x5xf32>
+///    %c3 = constant 3 : index
+///    %c4 = constant 4 : index
+///    %d3 = dim %0, 3 : memref<?x3x4x?x5xf32>
+///    %c5 = constant 5 : index
+/// ```
+///
+/// and returns the vector with {%d0, %c3, %c4, %d3, %c5}.
+static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
+                                              Value *memRef) {
+  auto memRefType = memRef->getType().template cast<MemRefType>();
+  SmallVector<Value *, 8> res;
+  res.reserve(memRefType.getShape().size());
+  const auto &shape = memRefType.getShape();
+  for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
+    if (isDynamicSize(shape[idx])) {
+      res.push_back(b->create<DimOp>(loc, memRef, idx));
+    } else {
+      res.push_back(b->create<ConstantIndexOp>(loc, shape[idx]));
+    }
+  }
+  return res;
+}
+
+SmallVector<edsc::Bindable, 8> MLIREmitter::makeBoundSizes(Value *memRef) {
+  assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
+  MemRefType memRefType = memRef->getType().cast<MemRefType>();
+  auto memRefSizes = edsc::makeBindables(memRefType.getShape().size());
+  auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef);
+  assert(memrefSizeValues.size() == memRefSizes.size());
+  bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
+  return memRefSizes;
+}
+
 } // namespace edsc
 } // namespace mlir
index 918c1602e16915323319bf473469a55a8094857b..bd19fbce2f63ac523ab0ddcd9500f4de11b73d6a 100644 (file)
@@ -62,52 +62,6 @@ using namespace mlir;
 
 #define DEBUG_TYPE "lower-vector-transfers"
 
-/// This function emits the proper Value* at the place of insertion of b,
-/// where each value is the proper ConstantOp or DimOp. Returns a vector with
-/// these Value*. Note this function does not concern itself with hoisting of
-/// constants and will produce redundant IR. Subsequent MLIR simplification
-/// passes like LICM and CSE are expected to clean this up.
-///
-/// More specifically, a MemRefType has a shape vector in which:
-///   - constant ranks are embedded explicitly with their value;
-///   - symbolic ranks are represented implicitly by -1 and need to be recovered
-///     with a DimOp operation.
-///
-/// Example:
-/// When called on:
-///
-/// ```mlir
-///    memref<?x3x4x?x5xf32>
-/// ```
-///
-/// This emits MLIR similar to:
-///
-/// ```mlir
-///    %d0 = dim %0, 0 : memref<?x3x4x?x5xf32>
-///    %c3 = constant 3 : index
-///    %c4 = constant 4 : index
-///    %d1 = dim %0, 0 : memref<?x3x4x?x5xf32>
-///    %c5 = constant 5 : index
-/// ```
-///
-/// and returns the vector with {%d0, %c3, %c4, %d1, %c5}.
-bool isDynamicSize(int size) { return size < 0; }
-SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
-                                       Value *memRef) {
-  auto memRefType = memRef->getType().template cast<MemRefType>();
-  SmallVector<Value *, 8> res;
-  res.reserve(memRefType.getShape().size());
-  unsigned countSymbolicShapes = 0;
-  for (int size : memRefType.getShape()) {
-    if (isDynamicSize(size)) {
-      res.push_back(b->create<DimOp>(loc, memRef, countSymbolicShapes++));
-    } else {
-      res.push_back(b->create<ConstantIndexOp>(loc, size));
-    }
-  }
-  return res;
-}
-
 namespace {
 /// Helper structure to hold information about loop nest, clipped accesses to
 /// the original scalar MemRef as well as full accesses to temporary MemRef in
@@ -215,12 +169,7 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
   auto ivs = makeBindables(vectorShape.size());
 
   // Create and bind Bindables to refer to the Value for memref sizes.
-  auto memRefSizes = makeBindables(memrefShape.size());
-  auto memrefSizeValues = getMemRefSizes(
-      emitter.getBuilder(), emitter.getLocation(), transfer->getMemRef());
-  assert(memrefSizeValues.size() == memRefSizes.size());
-  // Bind
-  emitter.bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
+  auto memRefSizes = emitter.makeBoundSizes(transfer->getMemRef());
 
   // Create the edsc::Expr for the clipped and transposes access expressions
   // using the permutationMap. Additionally, capture the index accessing the
index c5f914201a965885dd747de55b75a95a24898246..bbf6458f17e9c7ee3fbafb7273e66a505df98a12 100644 (file)
@@ -26,6 +26,31 @@ func @materialize_read_1d() {
   return
 }
 
+// CHECK-LABEL: func @materialize_read_1d_partially_specialized
+func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %dyn4 : index) {
+  %A = alloc (%dyn1, %dyn2, %dyn4) : memref<7x?x?x42x?xf32>
+  for %i0 = 0 to 7 {
+    for %i1 = 0 to %dyn1 {
+      for %i2 = 0 to %dyn2 {
+        for %i3 = 0 to 42 step 2 {
+          for %i4 = 0 to %dyn4 {
+            %f1 = vector_transfer_read %A, %i0, %i1, %i2, %i3, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32>
+            %i3p1 = affine_apply (d0) -> (d0 + 1) (%i3)
+            %f2 = vector_transfer_read %A, %i0, %i1, %i2, %i3p1, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32>
+          }
+        }
+      }
+    }
+  }
+  // CHECK: %[[tensor:[0-9]+]] = alloc
+  // CHECK-NOT: {{.*}} dim %[[tensor]], 0
+  // CHECK: {{.*}} dim %[[tensor]], 1
+  // CHECK: {{.*}} dim %[[tensor]], 2
+  // CHECK-NOT: {{.*}} dim %[[tensor]], 3
+  // CHECK: {{.*}} dim %[[tensor]], 4
+  return
+}
+
 // CHECK-LABEL: func @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
 func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
   // CHECK-NEXT:  %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
index 57b2c899b96ef5927681cacf8b00d19c8b4e0d63..cb249f5a0ee542e7912607b442efb70fb1026d87 100644 (file)
@@ -16,7 +16,7 @@ def X_AddOp : Op<"x.add">,
     auto *resultMemRef = *(f->getArguments().begin() + 2);
 
     Bindable lhs, rhs, result;
-    auto lhsShape = makeBoundSizes(&emitter, lhsMemRef);
+    auto lhsShape = emitter.makeBoundSizes(lhsMemRef);
 
     auto ivs = makeBindables(lhsShape.size());
     Bindable zero, one;