[mlir][sparse] reject kernels with non-sparsfiable reduction expression.
authorPeiming Liu <peiming@google.com>
Thu, 8 Dec 2022 20:03:18 +0000 (20:03 +0000)
committerPeiming Liu <peiming@google.com>
Thu, 8 Dec 2022 23:36:30 +0000 (23:36 +0000)
To address https://github.com/llvm/llvm-project/issues/59394.

Reduction on negation of the output tensor is a non-sparsifiable kernel, it creates cyclic dependency.

This patch reject those cases instead of crashing.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D139659

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Dialect/SparseTensor/rejected.mlir [new file with mode: 0644]

index 73d1c8f..1ff9f69 100644 (file)
@@ -271,6 +271,15 @@ public:
     return ldx >= numNativeLoops;
   }
 
+  /// Returns true if the expression contains the `t` as an operand.
+  bool expContainsTensor(unsigned e, unsigned t) const;
+
+  /// Returns true if the expression contains a negation on output tensor.
+  /// I.e., `- outTensor` or `exp - outputTensor`
+  /// NOTE: this is an trivial tests in that it does not handle recursive
+  /// negation, i.e., it returns true when the expression is `-(-tensor)`.
+  bool hasNegateOnOut(unsigned e) const;
+
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
@@ -348,9 +357,9 @@ public:
   void dumpBits(const BitVector &bits) const;
 #endif
 
-  /// Builds the iteration lattices in a bottom-up traversal given the remaining
-  /// tensor (sub)expression and the next loop index in the iteration graph.
-  /// Returns index of the root expression.
+  /// Builds the iteration lattices in a bottom-up traversal given the
+  /// remaining tensor (sub)expression and the next loop index in the
+  /// iteration graph. Returns index of the root expression.
   unsigned buildLattices(unsigned e, unsigned i);
 
   /// Builds a tensor expression from the given Linalg operation.
@@ -380,7 +389,8 @@ private:
   // Map that converts pair<tensor id, loop id> to the corresponding dimension
   // level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
-  // Map that converts pair<tensor id, loop id> to the corresponding dimension.
+  // Map that converts pair<tensor id, loop id> to the corresponding
+  // dimension.
   std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
   // Map that converts pair<tensor id, dim> to the corresponding loop id.
   std::vector<std::vector<Optional<unsigned>>> dimToLoopIdx;
index 81f3845..8fbbf6a 100644 (file)
@@ -583,6 +583,19 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
                                   std::vector<unsigned> &topSort, unsigned exp,
                                   OpOperand **sparseOut,
                                   unsigned &outerParNest) {
+  // We reject any expression that makes a reduction from `-outTensor`, as those
+  // expression create dependency between the current iteration (i) and the
+  // previous iteration (i-1). It would then require iterating over the whole
+  // coordinate space, which prevent us from exploiting sparsity for faster
+  // code.
+  for (utils::IteratorType it : op.getIteratorTypesArray()) {
+    if (it == utils::IteratorType::reduction) {
+      if (merger.hasNegateOnOut(exp))
+        return false;
+      break;
+    }
+  }
+
   OpOperand *lhs = op.getDpsInitOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
index 35530bf..bf66120 100644 (file)
 namespace mlir {
 namespace sparse_tensor {
 
+enum class ExpArity {
+  kNullary,
+  kUnary,
+  kBinary,
+};
+
+static ExpArity getExpArity(Kind k) {
+  switch (k) {
+  // Leaf.
+  case kTensor:
+  case kInvariant:
+  case kIndex:
+    return ExpArity::kNullary;
+  case kAbsF:
+  case kAbsC:
+  case kAbsI:
+  case kCeilF:
+  case kFloorF:
+  case kSqrtF:
+  case kSqrtC:
+  case kExpm1F:
+  case kExpm1C:
+  case kLog1pF:
+  case kLog1pC:
+  case kSinF:
+  case kSinC:
+  case kTanhF:
+  case kTanhC:
+  case kTruncF:
+  case kExtF:
+  case kCastFS:
+  case kCastFU:
+  case kCastSF:
+  case kCastUF:
+  case kCastS:
+  case kCastU:
+  case kCastIdx:
+  case kTruncI:
+  case kCIm:
+  case kCRe:
+  case kBitCast:
+  case kBinaryBranch:
+  case kUnary:
+  case kSelect:
+  case kNegF:
+  case kNegC:
+  case kNegI:
+    return ExpArity::kUnary;
+  // Binary operations.
+  case kDivF:
+  case kDivC:
+  case kDivS:
+  case kDivU:
+  case kShrS:
+  case kShrU:
+  case kShlI:
+  case kMulF:
+  case kMulC:
+  case kMulI:
+  case kAndI:
+  case kAddF:
+  case kAddC:
+  case kAddI:
+  case kOrI:
+  case kXorI:
+  case kBinary:
+  case kReduce:
+  case kSubF:
+  case kSubC:
+  case kSubI:
+    return ExpArity::kBinary;
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 //===----------------------------------------------------------------------===//
 // Constructors.
 //===----------------------------------------------------------------------===//
@@ -310,6 +385,57 @@ bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   return !hasAnySparse(tmp);
 }
 
+bool Merger::expContainsTensor(unsigned e, unsigned t) const {
+  if (tensorExps[e].kind == kTensor)
+    return tensorExps[e].tensor == t;
+
+  switch (getExpArity(tensorExps[e].kind)) {
+  case ExpArity::kNullary:
+    return false;
+  case ExpArity::kUnary: {
+    unsigned op = tensorExps[e].children.e0;
+    if (tensorExps[op].kind == kTensor && tensorExps[op].tensor == t)
+      return true;
+    return expContainsTensor(op, t);
+  }
+  case ExpArity::kBinary: {
+    unsigned op1 = tensorExps[e].children.e0;
+    unsigned op2 = tensorExps[e].children.e1;
+    if ((tensorExps[op1].kind == kTensor && tensorExps[op1].tensor == t) ||
+        (tensorExps[op2].kind == kTensor && tensorExps[op2].tensor == t))
+      return true;
+    return expContainsTensor(op1, t) || expContainsTensor(op2, t);
+  }
+  }
+  llvm_unreachable("unexpected arity");
+}
+
+bool Merger::hasNegateOnOut(unsigned e) const {
+  switch (tensorExps[e].kind) {
+  case kNegF:
+  case kNegC:
+  case kNegI:
+    return expContainsTensor(tensorExps[e].children.e0, outTensor);
+  case kSubF:
+  case kSubC:
+  case kSubI:
+    return expContainsTensor(tensorExps[e].children.e1, outTensor) ||
+           hasNegateOnOut(tensorExps[e].children.e0);
+  default: {
+    switch (getExpArity(tensorExps[e].kind)) {
+    case ExpArity::kNullary:
+      return false;
+    case ExpArity::kUnary:
+      return hasNegateOnOut(tensorExps[e].children.e0);
+    case ExpArity::kBinary:
+      return hasNegateOnOut(tensorExps[e].children.e0) ||
+             hasNegateOnOut(tensorExps[e].children.e1);
+    }
+  }
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
   // Leaf.
diff --git a/mlir/test/Dialect/SparseTensor/rejected.mlir b/mlir/test/Dialect/SparseTensor/rejected.mlir
new file mode 100644 (file)
index 0000000..63a10c5
--- /dev/null
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+
+// The file contains examples that will be rejected by sparse compiler
+// (we expect the linalg.generic unchanged).
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait = {
+  indexing_maps = [ 
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> ()>    // x (out)
+  ],  
+  iterator_types = ["reduction"]
+}
+
+// CHECK-LABEL:   func.func @sparse_reduction_subi(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<i32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor<i32> {
+// CHECK:           %[[VAL_2:.*]] = linalg.generic
+// CHECK:           ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
+// CHECK:             %[[VAL_5:.*]] = arith.subi %[[VAL_3]], %[[VAL_4]] : i32
+// CHECK:             linalg.yield %[[VAL_5]] : i32
+// CHECK:           } -> tensor<i32>
+// CHECK:           return %[[VAL_6:.*]] : tensor<i32>
+func.func @sparse_reduction_subi(%argx: tensor<i32>,
+                             %arga: tensor<?xi32, #SparseVector>)
+ -> tensor<i32> {
+  %0 = linalg.generic #trait
+     ins(%arga: tensor<?xi32, #SparseVector>)
+      outs(%argx: tensor<i32>) {
+      ^bb(%a: i32, %x: i32):
+        // NOTE: `subi %a, %x` is the reason why the program is rejected by the sparse compiler.
+        // It is because we do not allow `-outTensor` in reduction loops as it creates cyclic
+        // dependences.
+        %t = arith.subi %a, %x: i32 
+        linalg.yield %t : i32 
+  } -> tensor<i32>
+  return %0 : tensor<i32>
+}