[sparse][mlir][vectorization] add support for shift-by-invariant
authorAart Bik <ajcbik@google.com>
Fri, 23 Dec 2022 01:20:52 +0000 (17:20 -0800)
committerAart Bik <ajcbik@google.com>
Tue, 27 Dec 2022 19:07:13 +0000 (11:07 -0800)
Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140596

mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir

index 65af4e0..e652ebd 100644 (file)
@@ -50,6 +50,16 @@ static bool isIntValue(Value val, int64_t idx) {
   return false;
 }
 
+/// Helper test for invariant value (defined outside given block).
+static bool isInvariantValue(Value val, Block *block) {
+  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
+}
+
+/// Helper test for invariant argument (defined outside given block).
+static bool isInvariantArg(BlockArgument arg, Block *block) {
+  return arg.getOwner() != block;
+}
+
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
   unsigned numScalableDims = vl.enableVLAVectorization;
@@ -236,13 +246,15 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                 Value vmask, SmallVectorImpl<Value> &idxs) {
   unsigned d = 0;
   unsigned dim = subs.size();
+  Block *block = &forOp.getRegion().front();
   for (auto sub : subs) {
     bool innermost = ++d == dim;
     // Invariant subscripts in outer dimensions simply pass through.
     // Note that we rely on LICM to hoist loads where all subscripts
     // are invariant in the innermost loop.
-    if (sub.getDefiningOp() &&
-        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
+    // Example:
+    //   a[inv][i] for inv
+    if (isInvariantValue(sub, block)) {
       if (innermost)
         return false;
       if (codegen)
@@ -252,9 +264,10 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
     // Invariant block arguments (including outer loop indices) in outer
     // dimensions simply pass through. Direct loop indices in the
     // innermost loop simply pass through as well.
-    if (auto barg = sub.dyn_cast<BlockArgument>()) {
-      bool invariant = barg.getOwner() != &forOp.getRegion().front();
-      if (invariant == innermost)
+    // Example:
+    //   a[i][j] for both i and j
+    if (auto arg = sub.dyn_cast<BlockArgument>()) {
+      if (isInvariantArg(arg, block) == innermost)
         return false;
       if (codegen)
         idxs.push_back(sub);
@@ -281,6 +294,8 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
     // values, there is no good way to state that the indices are unsigned,
     // which creates the potential of incorrect address calculations in the
     // unlikely case we need such extremely large offsets.
+    // Example:
+    //    a[ ind[i] ]
     if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
       if (!innermost)
         return false;
@@ -303,18 +318,20 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
       continue; // success so far
     }
     // Address calculation 'i = add inv, idx' (after LICM).
+    // Example:
+    //    a[base + i]
     if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
       Value inv = load.getOperand(0);
       Value idx = load.getOperand(1);
-      if (inv.getDefiningOp() &&
-          inv.getDefiningOp()->getBlock() != &forOp.getRegion().front() &&
-          idx.dyn_cast<BlockArgument>()) {
-        if (!innermost)
-          return false;
-        if (codegen)
-          idxs.push_back(
-              rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
-        continue; // success so far
+      if (isInvariantValue(inv, block)) {
+        if (auto arg = idx.dyn_cast<BlockArgument>()) {
+          if (isInvariantArg(arg, block) || !innermost)
+            return false;
+          if (codegen)
+            idxs.push_back(
+                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
+          continue; // success so far
+        }
       }
     }
     return false;
@@ -389,7 +406,8 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
   }
   // Something defined outside the loop-body is invariant.
   Operation *def = exp.getDefiningOp();
-  if (def->getBlock() != &forOp.getRegion().front()) {
+  Block *block = &forOp.getRegion().front();
+  if (def->getBlock() != block) {
     if (codegen)
       vexp = genVectorInvariantValue(rewriter, vl, exp);
     return true;
@@ -450,6 +468,17 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                       vx) &&
         vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                       vy)) {
+      // We only accept shift-by-invariant (where the same shift factor applies
+      // to all packed elements). In the vector dialect, this is still
+      // represented with an expanded vector at the right-hand-side, however,
+      // so that we do not have to special case the code generation.
+      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
+          isa<arith::ShRSIOp>(def)) {
+        Value shiftFactor = def->getOperand(1);
+        if (!isInvariantValue(shiftFactor, block))
+          return false;
+      }
+      // Generate code.
       BINOP(arith::MulFOp)
       BINOP(arith::MulIOp)
       BINOP(arith::DivFOp)
@@ -462,8 +491,10 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
       BINOP(arith::AndIOp)
       BINOP(arith::OrIOp)
       BINOP(arith::XOrIOp)
+      BINOP(arith::ShLIOp)
+      BINOP(arith::ShRUIOp)
+      BINOP(arith::ShRSIOp)
       // TODO: complex?
-      // TODO: shift by invariant?
     }
   }
   return false;
index 32900d9..bf885f1 100644 (file)
@@ -17,6 +17,8 @@
 // CHECK-DAG:       %[[C1:.*]] = arith.constant dense<2.000000e+00> : vector<8xf32>
 // CHECK-DAG:       %[[C2:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
 // CHECK-DAG:       %[[C3:.*]] = arith.constant dense<255> : vector<8xi64>
+// CHECK-DAG:       %[[C4:.*]] = arith.constant dense<4> : vector<8xi32>
+// CHECK-DAG:       %[[C5:.*]] = arith.constant dense<1> : vector<8xi32>
 // CHECK:           scf.for
 // CHECK:             %[[VAL_14:.*]] = vector.load
 // CHECK:             %[[VAL_15:.*]] = math.absf %[[VAL_14]] : vector<8xf32>
 // CHECK:             %[[VAL_31:.*]] = arith.andi %[[VAL_30]], %[[C3]] : vector<8xi64>
 // CHECK:             %[[VAL_32:.*]] = arith.trunci %[[VAL_31]] : vector<8xi64> to vector<8xi16>
 // CHECK:             %[[VAL_33:.*]] = arith.extsi %[[VAL_32]] : vector<8xi16> to vector<8xi32>
-// CHECK:             %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : vector<8xi32> to vector<8xf32>
-// CHECK:             vector.store %[[VAL_34]]
+// CHECK:             %[[VAL_34:.*]] = arith.shrsi %[[VAL_33]], %[[C4]] : vector<8xi32>
+// CHECK:             %[[VAL_35:.*]] = arith.shrui %[[VAL_34]], %[[C4]] : vector<8xi32>
+// CHECK:             %[[VAL_36:.*]] = arith.shli %[[VAL_35]], %[[C5]] : vector<8xi32>
+// CHECK:             %[[VAL_37:.*]] = arith.uitofp %[[VAL_36]] : vector<8xi32> to vector<8xf32>
+// CHECK:             vector.store %[[VAL_37]]
 // CHECK:           }
 func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
                 %argb: tensor<1024xf32, #DenseVector>) -> tensor<1024xf32> {
@@ -47,6 +52,8 @@ func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
   %o = arith.constant 1.0 : f32
   %c = arith.constant 2.0 : f32
   %i = arith.constant 255 : i64
+  %s = arith.constant 4 : i32
+  %t = arith.constant 1 : i32
   %0 = linalg.generic #trait
     ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32, #DenseVector>)
     outs(%init: tensor<1024xf32>) {
@@ -69,8 +76,11 @@ func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
         %15 = arith.andi %14, %i : i64
         %16 = arith.trunci %15 : i64 to i16
         %17 = arith.extsi %16 : i16 to i32
-        %18 = arith.uitofp %17 : i32 to f32
-        linalg.yield %18 : f32
+       %18 = arith.shrsi %17, %s : i32
+       %19 = arith.shrui %18, %s : i32
+       %20 = arith.shli %19, %t : i32
+        %21 = arith.uitofp %20 : i32 to f32
+        linalg.yield %21 : f32
   } -> tensor<1024xf32>
   return %0 : tensor<1024xf32>
 }