[mlir][sparse] enable more sparse convolution kernels.
authorPeiming Liu <peiming@google.com>
Tue, 4 Apr 2023 18:41:00 +0000 (18:41 +0000)
committerPeiming Liu <peiming@google.com>
Mon, 17 Apr 2023 17:43:52 +0000 (17:43 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147670

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_1d_nwc_wcf.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_nhwc_hwcf.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_ndhwc_dhwcf.mlir

index d031868..f449e93 100644 (file)
@@ -1521,65 +1521,69 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
 // }
 ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
     OpBuilder &builder, Location loc, TensorId tid,
-    ArrayRef<const SliceInfo *> unResLvls, ValueRange userReduc,
+    ArrayRef<const SliceInfo *> unResLvls,
+    std::optional<std::pair<TensorId, Level>> firstResLvl, ValueRange userReduc,
     LoopBodyBuilder bodyBuilder) {
-  // assert(unResLvls.size() == 1 && "TODO");
-  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
-
-  const SliceInfo &frontSlice = *unResLvls.back();
-  Level firstLvl = *frontSlice.slicedOnLvl;
-  assert(!lvlFullyResolved(tid, firstLvl) && "TODO");
 
-  // FIXME: it is not zero when the first level is fully resolved.
+  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
   Value pos = c0;
   OpBuilder::InsertPoint ip;
   SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
-  scf::ForOp outerMost = nullptr;
-  if (!lvlFullyResolved(tid, firstLvl)) {
-    if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
-      unsigned depth = frontSlice.depth - 1;
-      Value offset = frontSlice.offset;
-      Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
-      Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
-      outerMost = builder.create<scf::ForOp>(
-          loc, c2, mSz, c2, innerArgs,
-          [this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos, &innerArgs](
-              OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
-            // generate traversal for each level.
-            Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
-            Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
-            ValueRange itArgs =
-                genSliceLvlTraverseLoop(
-                    builder, loc, loopLo, loopHi, offset,
-                    sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
-                    false,
-                    [&](OpBuilder &builder, Location, Value iv,
-                        MutableArrayRef<Value> reduc) {
-                      ip = builder.saveInsertionPoint();
-                      pos = iv;
-                      innerArgs.assign(reduc.begin(), reduc.end());
-                    })
-                    .second;
-            YIELD(itArgs);
-          });
-    } else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
-      assert(firstLvl == 0); // This must be the first level.
-      Value lb = frontSlice.offset;
-      Value sliceSz =
-          sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
-      Value ub = ADDI(lb, sliceSz);
-      outerMost = builder.create<scf::ForOp>(
-          loc, lb, ub, c1, innerArgs,
-          [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
-            ip = builder.saveInsertionPoint();
-            pos = iv;
-            innerArgs.assign(iterArgs.begin(), iterArgs.end());
-          });
+  scf::ForOp outerMost = nullptr; // the outtermost loop.
+  if (firstResLvl.has_value()) {
+    // Overwrite position when the first level is fully resolved.
+    pos = posits[firstResLvl->first][firstResLvl->second];
+    ip = builder.saveInsertionPoint();
+  } else {
+    const SliceInfo &frontSlice = *unResLvls.back();
+    Level firstLvl = *frontSlice.slicedOnLvl;
+    if (!lvlFullyResolved(tid, firstLvl)) {
+      if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
+        unsigned depth = frontSlice.depth - 1;
+        Value offset = frontSlice.offset;
+        Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
+        Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
+        outerMost = builder.create<scf::ForOp>(
+            loc, c2, mSz, c2, innerArgs,
+            [this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
+             &innerArgs](OpBuilder &builder, Location loc, Value iv,
+                         ValueRange iterArgs) {
+              // generate traversal for each level.
+              Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
+              Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
+              ValueRange itArgs =
+                  genSliceLvlTraverseLoop(
+                      builder, loc, loopLo, loopHi, offset,
+                      sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
+                      false,
+                      [&](OpBuilder &builder, Location, Value iv,
+                          MutableArrayRef<Value> reduc) {
+                        ip = builder.saveInsertionPoint();
+                        pos = iv;
+                        innerArgs.assign(reduc.begin(), reduc.end());
+                      })
+                      .second;
+              YIELD(itArgs);
+            });
+      } else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
+        assert(firstLvl == 0); // This must be the first level.
+        Value lb = frontSlice.offset;
+        Value sliceSz =
+            sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
+        Value ub = ADDI(lb, sliceSz);
+        outerMost = builder.create<scf::ForOp>(
+            loc, lb, ub, c1, innerArgs,
+            [&](OpBuilder &builder, Location loc, Value iv,
+                ValueRange iterArgs) {
+              ip = builder.saveInsertionPoint();
+              pos = iv;
+              innerArgs.assign(iterArgs.begin(), iterArgs.end());
+            });
+      }
+      // We generated the loop for the first slice above, now remove it.
+      unResLvls = unResLvls.drop_back();
     }
-    // We generated the loop for the first slice above, now remove it.
-    unResLvls = unResLvls.drop_back();
   }
-
   // Reset the insertion point into the loop body.
   builder.restoreInsertionPoint(ip);
   if (!unResLvls.empty()) {
@@ -1611,12 +1615,21 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
                              bodyBuilder(builder, loc, pos, innerArgs);
                              return innerArgs;
                            });
-    YIELD(denseNest.results);
+
+    if (!outerMost) {
+      // If the outermost loop has not been set, this is the outermost loop.
+      outerMost = denseNest.loops.front();
+    } else {
+      // Otherwise we need to generate yield operations to link the SSA chain.
+      YIELD(denseNest.results);
+    }
   } else {
+    assert(outerMost);
     // Generates user request loop body.
     bodyBuilder(builder, loc, pos, innerArgs);
     YIELD(innerArgs);
   }
+  assert(outerMost);
   // Insert after current while operation.
   builder.setInsertionPointAfter(outerMost);
   return outerMost.getResults();
@@ -1624,7 +1637,6 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
 
 void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
                                         TensorId tid, Level lvl) {
-  assert(lvl == 0 && "TODO: handle non-first level");
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2), c3 = C_IDX(3),
         c4 = C_IDX(4);
   if (isDenseDLT(lvlTypes[tid][lvl])) {
@@ -1634,14 +1646,23 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
                                  lvl, /*depth=*/1);
     return;
   }
-  Value size = sliceSizes[tid][0][0];
-  Value sPtrBuf = slicePosBuffer[tid][0][0];
-  Value pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
+  Value size = sliceSizes[tid][lvl][0];
+  Value sPtrBuf = slicePosBuffer[tid][lvl][0];
+  Value pHi, pLo;
+  if (lvl == 0) {
+    pLo = c0;
+    pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
+  } else {
+    pLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
+                       posits[tid][lvl - 1]);
+    pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
+                       ADDI(posits[tid][lvl - 1], c1));
+  }
   // Fills out pIdxBuffer[tid][lvl][0] with [/*memSize =*/4, 0, 0, pHi]
   builder.create<memref::StoreOp>(loc, c4, sPtrBuf, c0);  // memSize = 4
   builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);  // index = 0
-  builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c2);  // pLo = 0;
-  builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // loaded pHi.
+  builder.create<memref::StoreOp>(loc, pLo, sPtrBuf, c2); // pLo
+  builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // pHi
 
   // This is an non empty tensor if 0 < pHi.
   Value isNonEmpty = CMPI(ult, c0, pHi);
@@ -1703,10 +1724,15 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   assert(slicePosBuffer[tid][lvl - 1].size() == sliceStack[tid].back().depth);
 
   SmallVector<const SliceInfo *> unResSlices;
+  std::optional<std::pair<TensorId, Level>> firstResLvl;
   for (Level curLvl = lvl; curLvl >= 1; curLvl--) {
     Level prevLvl = curLvl - 1;
+    if (lvlFullyResolved(tid, prevLvl)) {
+      firstResLvl = std::make_pair(tid, prevLvl);
+      break;
+    }
     unResSlices.push_back(&getMostRecentSliceOnLvl(tid, prevLvl));
-    if (!isDenseDLT(lvlTypes[tid][prevLvl]) || lvlFullyResolved(tid, prevLvl)) {
+    if (!isDenseDLT(lvlTypes[tid][prevLvl])) {
       break;
     }
   }
@@ -1722,7 +1748,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   };
 
   ValueRange result = genUnResolvedSliceTreeTraverse(
-      builder, loc, tid, unResSlices, reduc,
+      builder, loc, tid, unResSlices, firstResLvl, reduc,
       [this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc,
                                         Value iv,
                                         MutableArrayRef<Value> reduc) {
@@ -1869,7 +1895,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
 void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
                                          TensorId tid, Level lvl) {
   for (unsigned i = 0; i <= lvl; i++) {
-    if (!isDenseDLT(lvlTypes[tid][i])) {
+    if (!isDenseDLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
       builder.create<memref::StoreOp>(loc, C_IDX(0),
                                       slicePosBuffer[tid][i].back(), C_IDX(1));
     }
index 554f24b..d069633 100644 (file)
@@ -452,11 +452,11 @@ private:
 
   /// Generates a nested loop that iterates over tid on all the coordinates on
   /// lvl.
-  ValueRange
-  genUnResolvedSliceTreeTraverse(OpBuilder &builder, Location loc, TensorId tid,
-                                 ArrayRef<const SliceInfo *> unResLvls,
-                                 ValueRange userReduc,
-                                 LoopBodyBuilder bodyBuilder);
+  ValueRange genUnResolvedSliceTreeTraverse(
+      OpBuilder &builder, Location loc, TensorId tid,
+      ArrayRef<const SliceInfo *> unResLvls,
+      std::optional<std::pair<TensorId, Level>> firstResLvl,
+      ValueRange userReduc, LoopBodyBuilder bodyBuilder);
 
   /// Generates code to get the first non-empty slice of tid on lvl, when all
   /// the previous level before `lvl` are resolved (or lvl is the first level).
index 94d8a2f..d7a4f23 100644 (file)
@@ -1,9 +1,4 @@
-// UNSUPPORTED: target={{.*}}
-// FIXME: The test case is disabled (for now) because affine index on sparse tensor
-// are not handled efficiently by sparse compiler, the test case will be re-enabled
-// after new algorithm is implemented.
-
-// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{option} = "enable-runtime-library=true enable-index-reduction=true"
 // DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void  \
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation and vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 
 // Do the same run, but now with direct IR generation and, if available, VLA
 // vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA enable-index-reduction=true"
 // REDEFINE: %{run} = %lli \
 // REDEFINE:   --entry-function=entry_lli \
 // REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
@@ -55,26 +50,26 @@ func.func @conv_1d_nwc_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %
   return %ret : tensor<?x?x?xf32>
 }
 
-func.func @conv_1d_nwc_wcf_CCC(%arg0: tensor<?x?x?xf32, #CCC>, %arg1: tensor<?x?x?xf32, #CCC>) -> tensor<?x?x?xf32, #CCC> {
+func.func @conv_1d_nwc_wcf_CCC(%arg0: tensor<?x?x?xf32, #CCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #CCC> {
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
   %c6 = arith.constant 6 : index
   %s = bufferization.alloc_tensor(%c3, %c6, %c1) : tensor<?x?x?xf32, #CCC>
   %ret = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
                                    strides = dense<1> : tensor<1xi64>}
-     ins (%arg0, %arg1: tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32, #CCC>)
+     ins (%arg0, %arg1: tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>)
     outs (%s: tensor<?x?x?xf32, #CCC>) -> tensor<?x?x?xf32, #CCC>
   return %ret : tensor<?x?x?xf32, #CCC>
 }
 
-func.func @conv_1d_nwc_wcf_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32, #CDC>) -> tensor<?x?x?xf32, #CDC> {
+func.func @conv_1d_nwc_wcf_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #CDC> {
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
   %c6 = arith.constant 6 : index
   %s = bufferization.alloc_tensor(%c3, %c6, %c1) : tensor<?x?x?xf32, #CDC>
   %ret = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
                                    strides = dense<1> : tensor<1xi64>}
-     ins (%arg0, %arg1: tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32, #CDC>)
+     ins (%arg0, %arg1: tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>)
     outs (%s: tensor<?x?x?xf32, #CDC>) -> tensor<?x?x?xf32, #CDC>
   return %ret : tensor<?x?x?xf32, #CDC>
 }
@@ -91,22 +86,18 @@ func.func @entry() {
 
   %in1D_tmp = call @alloc_3d_filled_f32(%c3, %c8, %c1, %val) : (index, index, index, f32) -> (tensor<?x?x?xf32>)
   %in1D_nwc = tensor.insert %f10 into %in1D_tmp[%c0, %c3, %c0] : tensor<?x?x?xf32>
+
   %filter1D_nwc = call @alloc_3d_filled_f32(%c3, %c1, %c1, %val) : (index, index, index, f32) -> (tensor<?x?x?xf32>)
   %out1D_nwc = call @alloc_3d_filled_f32(%c3, %c6, %c1, %zero) : (index, index, index, f32) -> (tensor<?x?x?xf32>)
 
   %in1D_nwc_CCC = sparse_tensor.convert %in1D_nwc
     : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
-  %filter1D_nwc_CCC = sparse_tensor.convert %filter1D_nwc
-    : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
-
   %in1D_nwc_CDC = sparse_tensor.convert %in1D_nwc
     : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
-  %filter1D_nwc_CDC = sparse_tensor.convert %filter1D_nwc
-    : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
 
   %dense_ret = call @conv_1d_nwc_wcf(%in1D_nwc, %filter1D_nwc, %out1D_nwc) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
-  %CCC_ret = call @conv_1d_nwc_wcf_CCC(%in1D_nwc_CCC, %filter1D_nwc_CCC) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32, #CCC>) -> (tensor<?x?x?xf32, #CCC>)
-  %CDC_ret = call @conv_1d_nwc_wcf_CDC(%in1D_nwc_CDC, %filter1D_nwc_CDC) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32, #CDC>) -> (tensor<?x?x?xf32, #CDC>)
+  %CCC_ret = call @conv_1d_nwc_wcf_CCC(%in1D_nwc_CCC, %filter1D_nwc) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
+  %CDC_ret = call @conv_1d_nwc_wcf_CDC(%in1D_nwc_CDC, %filter1D_nwc) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CDC>)
 
   //      CHECK: ( ( ( 12 ), ( 28 ), ( 28 ), ( 28 ), ( 12 ), ( 12 ) ),
   // CHECK-SAME:   ( ( 12 ), ( 12 ), ( 12 ), ( 12 ), ( 12 ), ( 12 ) ),
@@ -139,9 +130,7 @@ func.func @entry() {
   bufferization.dealloc_tensor %out1D_nwc : tensor<?x?x?xf32>
 
   bufferization.dealloc_tensor %in1D_nwc_CDC : tensor<?x?x?xf32, #CDC>
-  bufferization.dealloc_tensor %filter1D_nwc_CDC : tensor<?x?x?xf32, #CDC>
   bufferization.dealloc_tensor %in1D_nwc_CCC : tensor<?x?x?xf32, #CCC>
-  bufferization.dealloc_tensor %filter1D_nwc_CCC : tensor<?x?x?xf32, #CCC>
 
   bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
   bufferization.dealloc_tensor %CDC_ret : tensor<?x?x?xf32, #CDC>
index c7f8cfd..933272c 100644 (file)
@@ -1,9 +1,4 @@
-// UNSUPPORTED: target={{.*}}
-// FIXME: The test case is disabled (for now) because affine index on sparse tensor
-// are not handled efficiently by sparse compiler, the test case will be re-enabled
-// after new algorithm is implemented.
-
-// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{option} = "enable-runtime-library=true enable-index-reduction=true"
 // DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void  \
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation and vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true  enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 
 // Do the same run, but now with direct IR generation and, if available, VLA
 // vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA  enable-index-reduction=true"
 // REDEFINE: %{run} = %lli \
 // REDEFINE:   --entry-function=entry_lli \
 // REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
@@ -54,26 +49,26 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf
   return %ret : tensor<?x?x?x?xf32>
 }
 
-func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<?x?x?x?xf32, #CCCC>, %arg1: tensor<?x?x?x?xf32, #CCCC>) -> tensor<?x?x?x?xf32, #CCCC> {
+func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<?x?x?x?xf32, #CCCC>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32, #CCCC> {
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
   %c6 = arith.constant 6 : index
   %s = bufferization.alloc_tensor(%c3, %c6, %c6, %c1) : tensor<?x?x?x?xf32, #CCCC>
   %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                      strides = dense<1> : tensor<2xi64>}
-     ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32, #CCCC>)
+     ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>)
     outs (%s: tensor<?x?x?x?xf32, #CCCC>) -> tensor<?x?x?x?xf32, #CCCC>
   return %ret : tensor<?x?x?x?xf32, #CCCC>
 }
 
-func.func @conv_2d_nhwc_hwcf_CDCD(%arg0: tensor<?x?x?x?xf32, #CDCD>, %arg1: tensor<?x?x?x?xf32, #CDCD>) -> tensor<?x?x?x?xf32, #CDCD> {
+func.func @conv_2d_nhwc_hwcf_CDCD(%arg0: tensor<?x?x?x?xf32, #CDCD>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32, #CDCD> {
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
   %c6 = arith.constant 6 : index
   %s = bufferization.alloc_tensor(%c3, %c6, %c6, %c1) : tensor<?x?x?x?xf32, #CDCD>
   %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                      strides = dense<1> : tensor<2xi64>}
-     ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCD>, tensor<?x?x?x?xf32, #CDCD>)
+     ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCD>, tensor<?x?x?x?xf32>)
     outs (%s: tensor<?x?x?x?xf32, #CDCD>) -> tensor<?x?x?x?xf32, #CDCD>
   return %ret : tensor<?x?x?x?xf32, #CDCD>
 }
@@ -95,17 +90,12 @@ func.func @entry() {
 
   %in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
     : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
-  %filter2D_nhwc_CCCC = sparse_tensor.convert %filter2D_nhwc
-    : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
-
   %in2D_nhwc_CDCD = sparse_tensor.convert %in2D_nhwc
     : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCD>
-  %filter2D_nhwc_CDCD = sparse_tensor.convert %filter2D_nhwc
-    : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCD>
 
   %dense_ret = call @conv_2d_nhwc_hwcf(%in2D_nhwc, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
-  %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc_CCCC) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32, #CCCC>) -> (tensor<?x?x?x?xf32, #CCCC>)
-  %CDCD_ret = call @conv_2d_nhwc_hwcf_CDCD(%in2D_nhwc_CDCD, %filter2D_nhwc_CDCD) : (tensor<?x?x?x?xf32, #CDCD>, tensor<?x?x?x?xf32, #CDCD>) -> (tensor<?x?x?x?xf32, #CDCD>)
+  %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32, #CCCC>)
+  %CDCD_ret = call @conv_2d_nhwc_hwcf_CDCD(%in2D_nhwc_CDCD, %filter2D_nhwc) : (tensor<?x?x?x?xf32, #CDCD>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32, #CDCD>)
 
   // CHECK:     ( ( ( ( 108 ), ( 124 ), ( 124 ), ( 124 ), ( 108 ), ( 108 ) ),
   // CHECK-SAME:    ( ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ) ),
@@ -183,9 +173,7 @@ func.func @entry() {
   bufferization.dealloc_tensor %out2D_nhwc : tensor<?x?x?x?xf32>
 
   bufferization.dealloc_tensor %in2D_nhwc_CDCD : tensor<?x?x?x?xf32, #CDCD>
-  bufferization.dealloc_tensor %filter2D_nhwc_CDCD : tensor<?x?x?x?xf32, #CDCD>
   bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
-  bufferization.dealloc_tensor %filter2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
 
   bufferization.dealloc_tensor %CCCC_ret : tensor<?x?x?x?xf32, #CCCC>
   bufferization.dealloc_tensor %CDCD_ret : tensor<?x?x?x?xf32, #CDCD>
index f6363e1..5553f27 100644 (file)
@@ -1,9 +1,4 @@
-// UNSUPPORTED: target={{.*}}
-// FIXME: The test case is disabled (for now) because affine index on sparse tensor
-// are not handled efficiently by sparse compiler, the test case will be re-enabled
-// after new algorithm is implemented.
-
-// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{option} = "enable-runtime-library=true enable-index-reduction=true"
 // DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void  \
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 //
 // Do the same run, but now with direct IR generation and vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true enable-index-reduction=true"
 // RUN: %{compile} | %{run}
 
 // Do the same run, but now with direct IR generation and, if available, VLA
 // vectorization.
-// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA enable-index-reduction=true"
 // REDEFINE: %{run} = %lli \
 // REDEFINE:   --entry-function=entry_lli \
 // REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
@@ -57,7 +52,7 @@ func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor<?x?x?x?x?xf32>,
 }
 
 func.func @conv_3d_ndhwc_dhwcf_CCCCC(%arg0: tensor<?x?x?x?x?xf32, #CCCCC>,
-                                     %arg1: tensor<?x?x?x?x?xf32, #CCCCC>)
+                                     %arg1: tensor<?x?x?x?x?xf32>)
                                      -> tensor<?x?x?x?x?xf32, #CCCCC> {
   %c1 = arith.constant 1 : index
   %c6 = arith.constant 6 : index
@@ -65,13 +60,13 @@ func.func @conv_3d_ndhwc_dhwcf_CCCCC(%arg0: tensor<?x?x?x?x?xf32, #CCCCC>,
     : tensor<?x?x?x?x?xf32, #CCCCC>
   %ret = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>,
                                        strides = dense<1> : tensor<3xi64>}
-     ins (%arg0, %arg1: tensor<?x?x?x?x?xf32, #CCCCC>, tensor<?x?x?x?x?xf32, #CCCCC>)
+     ins (%arg0, %arg1: tensor<?x?x?x?x?xf32, #CCCCC>, tensor<?x?x?x?x?xf32>)
     outs (%s: tensor<?x?x?x?x?xf32, #CCCCC>) -> tensor<?x?x?x?x?xf32, #CCCCC>
   return %ret : tensor<?x?x?x?x?xf32, #CCCCC>
 }
 
 func.func @conv_3d_ndhwc_dhwcf_CDCDC(%arg0: tensor<?x?x?x?x?xf32, #CDCDC>,
-                                     %arg1: tensor<?x?x?x?x?xf32, #CDCDC>)
+                                     %arg1: tensor<?x?x?x?x?xf32>)
                                      -> tensor<?x?x?x?x?xf32, #CDCDC> {
   %c1 = arith.constant 1 : index
   %c6 = arith.constant 6 : index
@@ -79,7 +74,7 @@ func.func @conv_3d_ndhwc_dhwcf_CDCDC(%arg0: tensor<?x?x?x?x?xf32, #CDCDC>,
     : tensor<?x?x?x?x?xf32, #CDCDC>
   %ret = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>,
                                        strides = dense<1> : tensor<3xi64>}
-     ins (%arg0, %arg1: tensor<?x?x?x?x?xf32, #CDCDC>, tensor<?x?x?x?x?xf32, #CDCDC>)
+     ins (%arg0, %arg1: tensor<?x?x?x?x?xf32, #CDCDC>, tensor<?x?x?x?x?xf32>)
     outs (%s: tensor<?x?x?x?x?xf32, #CDCDC>) -> tensor<?x?x?x?x?xf32, #CDCDC>
   return %ret : tensor<?x?x?x?x?xf32, #CDCDC>
 }
@@ -102,13 +97,8 @@ func.func @entry() {
 
   %in3D_ndhwc_CCCCC = sparse_tensor.convert %in3D_ndhwc
     : tensor<?x?x?x?x?xf32> to tensor<?x?x?x?x?xf32, #CCCCC>
-  %filter3D_ndhwc_CCCCC = sparse_tensor.convert %filter3D_ndhwc
-    : tensor<?x?x?x?x?xf32> to tensor<?x?x?x?x?xf32, #CCCCC>
-
   %in3D_ndhwc_CDCDC = sparse_tensor.convert %in3D_ndhwc
     : tensor<?x?x?x?x?xf32> to tensor<?x?x?x?x?xf32, #CDCDC>
-  %filter3D_ndhwc_CDCDC = sparse_tensor.convert %filter3D_ndhwc
-    : tensor<?x?x?x?x?xf32> to tensor<?x?x?x?x?xf32, #CDCDC>
 
   //      CHECK:( ( ( ( ( 108 ), ( 124 ), ( 124 ), ( 124 ), ( 108 ), ( 108 ) ),
   // CHECK-SAME:      ( ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ) ),
@@ -152,9 +142,9 @@ func.func @entry() {
       : tensor<?x?x?x?x?xf32>, vector<1x6x6x6x1xf32>
   vector.print %dense_v : vector<1x6x6x6x1xf32>
 
-  %CCCCC_ret = call @conv_3d_ndhwc_dhwcf_CCCCC(%in3D_ndhwc_CCCCC, %filter3D_ndhwc_CCCCC)
+  %CCCCC_ret = call @conv_3d_ndhwc_dhwcf_CCCCC(%in3D_ndhwc_CCCCC, %filter3D_ndhwc)
       : (tensor<?x?x?x?x?xf32, #CCCCC>,
-         tensor<?x?x?x?x?xf32, #CCCCC>) -> (tensor<?x?x?x?x?xf32, #CCCCC>)
+         tensor<?x?x?x?x?xf32>) -> (tensor<?x?x?x?x?xf32, #CCCCC>)
 
   // CHECK-NEXT:( ( ( ( ( 108 ), ( 124 ), ( 124 ), ( 124 ), ( 108 ), ( 108 ) ),
   // CHECK-SAME:      ( ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ) ),
@@ -198,9 +188,9 @@ func.func @entry() {
       : tensor<?x?x?x?x?xf32>, vector<1x6x6x6x1xf32>
   vector.print %v1 : vector<1x6x6x6x1xf32>
 
-  %CDCDC_ret = call @conv_3d_ndhwc_dhwcf_CDCDC(%in3D_ndhwc_CDCDC, %filter3D_ndhwc_CDCDC)
+  %CDCDC_ret = call @conv_3d_ndhwc_dhwcf_CDCDC(%in3D_ndhwc_CDCDC, %filter3D_ndhwc)
       : (tensor<?x?x?x?x?xf32, #CDCDC>,
-         tensor<?x?x?x?x?xf32, #CDCDC>) -> (tensor<?x?x?x?x?xf32, #CDCDC>)
+         tensor<?x?x?x?x?xf32>) -> (tensor<?x?x?x?x?xf32, #CDCDC>)
 
   // CHECK-NEXT:( ( ( ( ( 108 ), ( 124 ), ( 124 ), ( 124 ), ( 108 ), ( 108 ) ),
   // CHECK-SAME:      ( ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ), ( 108 ) ),
@@ -250,9 +240,7 @@ func.func @entry() {
   bufferization.dealloc_tensor %out3D_ndhwc : tensor<?x?x?x?x?xf32>
 
   bufferization.dealloc_tensor %in3D_ndhwc_CDCDC : tensor<?x?x?x?x?xf32, #CDCDC>
-  bufferization.dealloc_tensor %filter3D_ndhwc_CDCDC : tensor<?x?x?x?x?xf32, #CDCDC>
   bufferization.dealloc_tensor %in3D_ndhwc_CCCCC : tensor<?x?x?x?x?xf32, #CCCCC>
-  bufferization.dealloc_tensor %filter3D_ndhwc_CCCCC : tensor<?x?x?x?x?xf32, #CCCCC>
 
   bufferization.dealloc_tensor %CCCCC_ret : tensor<?x?x?x?x?xf32, #CCCCC>
   bufferization.dealloc_tensor %CDCDC_ret : tensor<?x?x?x?x?xf32, #CDCDC>