[mlir][sparse] use shared util for DimOp generation
authorAart Bik <ajcbik@google.com>
Wed, 18 Aug 2021 17:39:14 +0000 (10:39 -0700)
committerAart Bik <ajcbik@google.com>
Thu, 19 Aug 2021 00:12:32 +0000 (17:12 -0700)
This shares more code with existing utilities. Also, to be consistent,
we moved dimension permutation on the DimOp to the tensor lowering phase.
This way, both pre-existing DimOps on sparse tensors (not likely but
possible) as well as compiler generated DimOps are handled consistently.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D108309

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir [new file with mode: 0644]

index c5f4b07..01b1800 100644 (file)
@@ -232,12 +232,33 @@ public:
   LogicalResult
   matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
-      return failure();
     Type resType = op.getType();
+    auto enc = getSparseTensorEncoding(op.source().getType());
+    if (!enc)
+      return failure();
+    // Permute the dim index.
+    Optional<int64_t> index = op.getConstantIndex();
+    if (!index.hasValue())
+      return failure();
+    int64_t idx = index.getValue();
+    AffineMap p = enc.getDimOrdering();
+    if (p) {
+      assert(p.isPermutation());
+      for (unsigned i = 0, sz = p.getNumResults(); i < sz; i++) {
+        if (p.getDimPosition(i) == idx) {
+          idx = i;
+          break;
+        }
+      }
+    }
+    // Generate the call.
     StringRef name = "sparseDimSize";
+    SmallVector<Value, 2> params;
+    params.push_back(operands[0]);
+    params.push_back(
+        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
     rewriter.replaceOpWithNewOp<CallOp>(
-        op, resType, getFunc(op, name, resType, operands), operands);
+        op, resType, getFunc(op, name, resType, params), params);
     return success();
   }
 };
index 2567693..2478455 100644 (file)
@@ -282,17 +282,11 @@ static bool genBuffers(Merger &merger, CodeGen &codegen,
         codegen.indices[tensor][idx] =
             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
       }
-      // Find lower and upper bound in current dimension. Note that a
-      // permuted encoding queries static type dimensions accordingly,
-      // but queries dynamic type dimensions in the generated order.
-      Value up;
+      // Find upper bound in current dimension.
       unsigned p = perm(enc, d);
-      if (shape[p] == MemRefType::kDynamicSize) {
-        up = rewriter.create<tensor::DimOp>(loc, t->get(), d);
+      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
+      if (shape[p] == MemRefType::kDynamicSize)
         args.push_back(up);
-      } else {
-        up = rewriter.create<ConstantIndexOp>(loc, shape[p]);
-      }
       assert(codegen.highs[tensor][idx] == nullptr);
       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
     }
index a2a2b4d..2b479b5 100644 (file)
   dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
 }>
 
-// CHECK-LABEL: func @sparse_dim(
+// CHECK-LABEL: func @sparse_dim1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index
 //       CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
 //       CHECK: return %[[D]] : index
-func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
+func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
   %c = constant 0 : index
   %0 = tensor.dim %arg0, %c : tensor<?xf64, #SparseVector>
   return %0 : index
 }
 
+// CHECK-LABEL: func @sparse_dim3d(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[C:.*]] = constant 2 : index
+//       CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
+//       CHECK: return %[[D]] : index
+func @sparse_dim3d(%arg0: tensor<?x?x?xf64, #SparseTensor>) -> index {
+  // Needs permuting 1 into 2.
+  %c = constant 1 : index
+  %0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #SparseTensor>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_new1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[U:.*]] = constant dense<1> : tensor<1xi8>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
new file mode 100644 (file)
index 0000000..405b44b
--- /dev/null
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -sparsification --canonicalize | FileCheck %s --check-prefix=CHECK-HIR
+//
+// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --canonicalize | \
+// RUN: FileCheck %s --check-prefix=CHECK-MIR
+
+#X = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "dense", "dense" ],
+ dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
+}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (k,i,j)>,  // A (in)
+    affine_map<(i,j,k) -> ()>        // X (out)
+  ],
+  iterator_types = ["reduction", "reduction", "reduction"]
+}
+
+// CHECK-HIR-LABEL:   builtin.func @sparse_dynamic_dims(
+// CHECK-HIR-SAME:                                      %[[VAL_0:.*]]: tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-HIR-SAME:                                      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-HIR-DAG:       %[[C0:.*]] = constant 0 : index
+// CHECK-HIR-DAG:       %[[C1:.*]] = constant 1 : index
+// CHECK-HIR-DAG:       %[[C2:.*]] = constant 2 : index
+// CHECK-HIR:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[C2]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[C1]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32,  #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
+// CHECK-HIR:           %[[VAL_10:.*]] = memref.alloc() : memref<f32>
+// CHECK-HIR:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
+// CHECK-HIR:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[VAL_5]] step %[[C1]] {
+// CHECK-HIR:             scf.for %[[VAL_12:.*]] = %[[C0]] to %[[VAL_6]] step %[[C1]] {
+// CHECK-HIR:               %[[VAL_13:.*]] = muli %[[VAL_6]], %[[VAL_11]] : index
+// CHECK-HIR:               %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK-HIR:               %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK-HIR:               %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[C0]] to %[[VAL_7]] step %[[C1]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
+// CHECK-HIR:                 %[[VAL_19:.*]] = muli %[[VAL_7]], %[[VAL_14]] : index
+// CHECK-HIR:                 %[[VAL_20:.*]] = addi %[[VAL_19]], %[[VAL_17]] : index
+// CHECK-HIR:                 %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xf32>
+// CHECK-HIR:                 %[[VAL_22:.*]] = addf %[[VAL_18]], %[[VAL_21]] : f32
+// CHECK-HIR:                 scf.yield %[[VAL_22]] : f32
+// CHECK-HIR:               }
+// CHECK-HIR:               memref.store %[[VAL_23:.*]], %[[VAL_10]][] : memref<f32>
+// CHECK-HIR:             }
+// CHECK-HIR:           }
+// CHECK-HIR:           %[[VAL_24:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
+// CHECK-HIR:           return %[[VAL_24]] : tensor<f32>
+// CHECK-HIR:         }
+//
+// CHECK-MIR-LABEL:   builtin.func @sparse_dynamic_dims(
+// CHECK-MIR-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME:                                      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-MIR-DAG:       %[[C0:.*]] = constant 0 : index
+// CHECK-MIR-DAG:       %[[C1:.*]] = constant 1 : index
+// CHECK-MIR-DAG:       %[[C2:.*]] = constant 2 : index
+// CHECK-MIR:           %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C0]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-MIR:           %[[VAL_6:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C1]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-MIR:           %[[VAL_7:.*]] = call @sparseDimSize(%[[VAL_0]], %[[C2]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-MIR:           %[[VAL_8:.*]] = call @sparseValuesF32(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf32>
+// CHECK-MIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
+// CHECK-MIR:           %[[VAL_10:.*]] = memref.alloc() : memref<f32>
+// CHECK-MIR:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
+// CHECK-MIR:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[VAL_5]] step %[[C1]] {
+// CHECK-MIR:             scf.for %[[VAL_12:.*]] = %[[C0]] to %[[VAL_6]] step %[[C1]] {
+// CHECK-MIR:               %[[VAL_13:.*]] = muli %[[VAL_6]], %[[VAL_11]] : index
+// CHECK-MIR:               %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK-MIR:               %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK-MIR:               %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[C0]] to %[[VAL_7]] step %[[C1]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
+// CHECK-MIR:                 %[[VAL_19:.*]] = muli %[[VAL_7]], %[[VAL_14]] : index
+// CHECK-MIR:                 %[[VAL_20:.*]] = addi %[[VAL_19]], %[[VAL_17]] : index
+// CHECK-MIR:                 %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xf32>
+// CHECK-MIR:                 %[[VAL_22:.*]] = addf %[[VAL_18]], %[[VAL_21]] : f32
+// CHECK-MIR:                 scf.yield %[[VAL_22]] : f32
+// CHECK-MIR:               }
+// CHECK-MIR:               memref.store %[[VAL_23:.*]], %[[VAL_10]][] : memref<f32>
+// CHECK-MIR:             }
+// CHECK-MIR:           }
+// CHECK-MIR:           %[[VAL_24:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
+// CHECK-MIR:           return %[[VAL_24]] : tensor<f32>
+// CHECK-MIR:         }
+func @sparse_dynamic_dims(%arga: tensor<?x?x?xf32, #X>,
+                          %argx: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.generic #trait
+    ins(%arga: tensor<?x?x?xf32, #X>)
+    outs(%argx: tensor<f32>) {
+      ^bb(%a : f32, %x: f32):
+        %0 = addf %x, %a : f32
+        linalg.yield %0 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}