[mlir][Linalg] Enable vectorization of explicit broadcasts
authorDiego Caballero <diegocaballero@google.com>
Tue, 12 Oct 2021 20:58:06 +0000 (20:58 +0000)
committerDiego Caballero <diegocaballero@google.com>
Tue, 12 Oct 2021 21:08:22 +0000 (21:08 +0000)
This patch teaches `isProjectedPermutation` and `inverseAndBroadcastProjectedPermutation`
utilities to deal with maps representing an explicit broadcast, e.g., (d0, d1) -> (d0, 0).
This extension is needed to enable vectorization of such explicit broadcast in Linalg.

Reviewed By: pifon2a, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D111563

mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 906c53d..40052d7 100644 (file)
@@ -273,8 +273,11 @@ public:
   SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;
 
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
-  /// symbol-less permutation map.
-  bool isProjectedPermutation() const;
+  /// symbol-less permutation map. `allowZeroInResults` allows projected
+  /// permutation maps with constant zero result expressions.
+  /// TODO: Remove `allowZeroInResults` when constant zero result expressions
+  /// are broadly supported.
+  bool isProjectedPermutation(bool allowZeroInResults = false) const;
 
   /// Returns true if the AffineMap represents a symbol-less permutation map.
   bool isPermutation() const;
@@ -464,6 +467,17 @@ AffineMap inversePermutation(AffineMap map);
 /// ```mlir
 ///    affine_map<(d0) -> (0, 0, d0, 0)>
 /// ```
+/// Example 4:
+///
+/// ```mlir
+///    affine_map<(d0, d1, d2) -> (d0, 0)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+///    affine_map<(d0, d1) -> (d0, 0, 0)>
+/// ```
 AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
 
 /// Concatenates a list of `maps` into a single AffineMap, stepping over
@@ -518,9 +532,16 @@ SmallVector<T> applyPermutationMap(AffineMap map, llvm::ArrayRef<T> source) {
   assert(map.getNumInputs() == source.size());
   SmallVector<T> result;
   result.reserve(map.getNumResults());
-  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
-    unsigned dim = map.getDimPosition(i);
-    result.push_back(source[dim]);
+  for (AffineExpr expr : map.getResults()) {
+    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+      result.push_back(source[dimExpr.getPosition()]);
+    } else if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      assert(constExpr.getValue() == 0 &&
+             "Unexpected constant in projected permutation map");
+      result.push_back(0);
+    } else {
+      llvm_unreachable("Unexpected result in projected permutation map");
+    }
   }
   return result;
 }
index b640813..cbe7e3f 100644 (file)
@@ -77,7 +77,8 @@ static OpType getSingleOpOfType(Block &block) {
 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
 static AffineMap reindexIndexingMap(AffineMap map) {
-  assert(map.isProjectedPermutation() && "expected projected permutation");
+  assert(map.isProjectedPermutation(/*allowZerosInResults=*/true) &&
+         "expected projected permutation");
   auto res = compressUnusedDims(map);
   assert(res.getNumDims() == res.getNumResults() &&
          "expected reindexed map with same number of dims and results");
@@ -593,8 +594,9 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
 }
 
 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
-  return llvm::all_of(op.getIndexingMaps(),
-                      [](AffineMap m) { return m.isProjectedPermutation(); });
+  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
+    return m.isProjectedPermutation(/*allowZerosInResults=*/true);
+  });
 }
 
 // TODO: probably need some extra checks for reduction followed by consumer
index 9c6f25d..beb47d1 100644 (file)
@@ -495,19 +495,33 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
   return res;
 }
 
-bool AffineMap::isProjectedPermutation() const {
+bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   if (getNumSymbols() > 0)
     return false;
+
+  // Having more results than inputs means that results have duplicated dims or
+  // zeros that can't be mapped to input dims.
+  if (getNumResults() > getNumInputs())
+    return false;
+
   SmallVector<bool, 8> seen(getNumInputs(), false);
+  // A projected permutation can have, at most, only one instance of each input
+  // dimension in the result expressions. Zeros are allowed as long as the
+  // number of result expressions is lower or equal than the number of input
+  // expressions.
   for (auto expr : getResults()) {
     if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
       if (seen[dim.getPosition()])
         return false;
       seen[dim.getPosition()] = true;
-      continue;
+    } else {
+      auto constExpr = expr.dyn_cast<AffineConstantExpr>();
+      if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
+        return false;
     }
-    return false;
   }
+
+  // Results are either dims or zeros and zeros can be mapped to input dims.
   return true;
 }
 
@@ -696,13 +710,21 @@ AffineMap mlir::inversePermutation(AffineMap map) {
 }
 
 AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
-  assert(map.isProjectedPermutation());
+  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
   MLIRContext *context = map.getContext();
   AffineExpr zero = mlir::getAffineConstantExpr(0, context);
   // Start with all the results as 0.
   SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
   for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
-    // Reverse each dimension existing in the oringal map result.
+    // Skip zeros from input map. 'exprs' is already initialized to zero.
+    if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) {
+      assert(constExpr.getValue() == 0 &&
+             "Unexpected constant in projected permutation");
+      (void)constExpr;
+      continue;
+    }
+
+    // Reverse each dimension existing in the original map result.
     exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
   }
   return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
index c3a4a05..1e6c801 100644 (file)
@@ -863,6 +863,65 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
 
 // -----
 
+// CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+
+// CHECK-LABEL:   func @explicit_broadcast(
+func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M5]]} : tensor<4x1xf32>, vector<4x4xf32>
+  // CHECK: subf {{.*}} : vector<4x4xf32>
+  // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
+  %c0 = constant 0.0 : f32
+  %init = linalg.init_tensor [4, 4] : tensor<4x4xf32>
+  %fill = linalg.fill(%c0, %init) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0, 0)>,
+                                          affine_map<(d0, d1) -> (d0, d1)>],
+   iterator_types = ["parallel", "parallel"]}
+   ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x1xf32>)
+   outs(%fill : tensor<4x4xf32>) {
+    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+      %40 = subf %arg7, %arg8 : f32
+      linalg.yield %40 : f32
+    } -> tensor<4x4xf32>
+  return %red : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-DAG: #[[$M7:.*]] = affine_map<(d0) -> (d0, 0)>
+
+// CHECK-LABEL:   func @fused_broadcast_red_2d
+func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4xf32> {
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M7]]} : tensor<4xf32>, vector<4x4xf32>
+  // CHECK: subf {{.*}} : vector<4x4xf32>
+  // CHECK: math.exp {{.*}} : vector<4x4xf32>
+  // CHECK: addf {{.*}} : vector<4x4xf32>
+  // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
+  %c0 = constant 0.0 : f32
+  %init = linalg.init_tensor [4] : tensor<4xf32>
+  %fill = linalg.fill(%c0, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0, 0)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x1xf32>)
+   outs(%fill : tensor<4xf32>) {
+    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+      %40 = subf %arg7, %arg8 : f32
+      %41 = math.exp %40 : f32
+      %42 = addf %41, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<4xf32>
+  return %red : tensor<4xf32>
+}
+
+// -----
+
 //  CHECK-LABEL: func @reduce_1d(
 //   CHECK-SAME:   %[[A:.*]]: tensor<32xf32>
 func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
@@ -899,4 +958,3 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
 
   return %2 : tensor<f32>
 }
-