[mlir][linalg] Extend linalg vectorization to support non-identity input maps
authorthomasraoux <thomasraoux@google.com>
Tue, 16 Mar 2021 21:14:51 +0000 (14:14 -0700)
committerthomasraoux <thomasraoux@google.com>
Thu, 18 Mar 2021 19:32:35 +0000 (12:32 -0700)
This propagates the affine map to transfer_read op in case it is not a
minor identity map.

Differential Revision: https://reviews.llvm.org/D98523

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 880e7f3..dab32d2 100644 (file)
@@ -87,11 +87,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
 /// Build a vector.transfer_read from `source` at indices set to all `0`.
 /// If source has rank zero, build an memref.load.
 /// Return the produced value.
-static Value buildVectorRead(OpBuilder &builder, Value source) {
+static Value buildVectorRead(OpBuilder &builder, Value source,
+                             VectorType vectorType, AffineMap map) {
   edsc::ScopedContext scope(builder);
   auto shapedType = source.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+  if (vectorType) {
     SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+    if (map)
+      return vector_transfer_read(vectorType, source, indices, map);
     return vector_transfer_read(vectorType, source, indices);
   }
   return memref_load(source);
@@ -238,6 +241,51 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
                              builder.createOperation(state)};
 }
 
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  for (Operation &op : r.front()) {
+    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+          OpTrait::hasElementwiseMappableTraits(&op)) ||
+        llvm::any_of(op.getResultTypes(),
+                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+      return false;
+  }
+  return true;
+}
+
+// Return true if the op is an element-wise linalg op.
+static bool isElementwise(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return false;
+  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+    return false;
+  // TODO: relax the restrictions on indexing map.
+  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
+    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+      return false;
+  }
+  if (linalgOp->getNumRegions() != 1)
+    return false;
+  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
+}
+
+// Calculate the map to apply to transfer_read to convert the input shape into
+// the output shape.
+static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
+  AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
+  MLIRContext *context = linalgMap.getContext();
+  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
+  SmallVector<AffineExpr, 4> exprs(linalgMap.getNumInputs(), zero);
+  for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) {
+    exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
+  }
+  return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
+                        context);
+}
+
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
 /// vector form. Generic vectorization proceeds as follows:
 ///   1. The region for the linalg op is created if necessary.
@@ -282,7 +330,19 @@ LogicalResult vectorizeAsLinalgGeneric(
   SmallVector<AffineMap> indexings;
   for (auto bbarg : block->getArguments()) {
     Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
-    Value vectorRead = buildVectorRead(builder, vectorArg);
+    AffineMap map;
+    VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
+    if (isElementwise(linalgOp) &&
+        !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
+      // Currently assume we don't support output permutations.
+      assert(linalgOp.getNumOutputs() > 0 &&
+             linalgOp.getOutputIndexingMap(0).isIdentity());
+      ArrayRef<int64_t> outputShape =
+          linalgOp.getOutputShapedType(0).getShape();
+      vectorType = VectorType::get(outputShape, vectorType.getElementType());
+      map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
+    }
+    Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                       << bbarg.getArgNumber() << "): " << vectorRead);
     bvm.map(bbarg, vectorRead);
@@ -316,44 +376,6 @@ LogicalResult vectorizeAsLinalgGeneric(
   return success();
 }
 
-/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
-static bool hasOnlyScalarElementwiseOp(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  for (Operation &op : r.front()) {
-    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
-          OpTrait::hasElementwiseMappableTraits(&op)) ||
-        llvm::any_of(op.getResultTypes(),
-                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
-      return false;
-  }
-  return true;
-}
-
-// Return true if the op is an element-wise linalg op.
-static bool isElementwise(Operation *op) {
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return false;
-  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
-    return false;
-  // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
-    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
-      return false;
-  }
-  // Currently bound the input indexing map to minor identity as other
-  // permutations might require adding transpose ops to convert the vector read
-  // to the right shape.
-  for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
-    if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
-      return false;
-  }
-  if (linalgOp->getNumRegions() != 1)
-    return false;
-  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
-}
-
 static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
                                           SmallVectorImpl<Value> &newResults) {
   assert(isaContractionOpInterface(linalgOp) &&
index 6ca28ba..08bf762 100644 (file)
@@ -2294,8 +2294,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
 
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 2> elidedAttrs;
-  if (op.permutation_map() ==
-      getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
+  if (op.permutation_map().isMinorIdentity())
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideMasked = true;
   if (auto maybeMasked = op.masked()) {
index 9de80e9..98ca45b 100644 (file)
@@ -106,8 +106,9 @@ AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
 }
 
 bool AffineMap::isMinorIdentity() const {
-  return *this ==
-         getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
+  return getNumDims() >= getNumResults() &&
+         *this ==
+             getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
 /// Returns true if this affine map is a minor identity up to broadcasted
index c43bf07..74ff436 100644 (file)
@@ -341,6 +341,42 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
 
 // -----
 
+// Test different input maps.
+#matmul_trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (0, d1, 0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
+//       CHECK: func @vectorization_transpose
+//       CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP0]]} : memref<14x7xf32>, vector<7x14x8x16xf32>
+//       CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP1]]} : memref<16x14xf32>, vector<7x14x8x16xf32>
+//       CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP2]]} : memref<16x14x7x8xf32>, vector<7x14x8x16xf32>
+//       CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+//       CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+//       CHECK: vector.transfer_write {{.*}} : vector<7x14x8x16xf32>, memref<7x14x8x16xf32>
+func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
+                         %C: memref<16x14x7x8xf32>, %D: memref<7x14x8x16xf32>) {
+  linalg.generic #matmul_trait
+    ins(%A, %B, %C : memref<14x7xf32>, memref<16x14xf32>, memref<16x14x7x8xf32>)
+   outs(%D : memref<7x14x8x16xf32>) {
+    ^bb(%a: f32, %b: f32, %c: f32, %d: f32) :
+      %e = addf %a, %b: f32
+      %f = addf %e, %c: f32
+      linalg.yield %f : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_tensors
 //  CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
 //  CHECK-SAME:  %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>