[mlir][sparse] add support for std unary operations
authorAart Bik <ajcbik@google.com>
Tue, 13 Jul 2021 19:13:39 +0000 (12:13 -0700)
committerAart Bik <ajcbik@google.com>
Tue, 13 Jul 2021 21:51:13 +0000 (14:51 -0700)
Adds zero-preserving unary operators from std. Also adds xor.
Performs minor refactoring to remove "zero" node, and pushed
the irregular logic for negi (not support in std) into one place.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D105928

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

index d81cb37..aa1e52f 100644 (file)
@@ -21,15 +21,20 @@ namespace mlir {
 namespace sparse_tensor {
 
 /// Dimension level type for a tensor (undef means index does not appear).
-enum class Dim { kSparse, kDense, kSingle, kUndef };
+enum Dim { kSparse, kDense, kSingle, kUndef };
 
 /// Tensor expression kind.
-enum class Kind {
+enum Kind {
   // Leaf.
-  kTensor,
+  kTensor = 0,
   kInvariant,
-  kZero,
-  // Operation.
+  // Unary operations.
+  kAbsF,
+  kCeilF,
+  kFloorF,
+  kNegF,
+  kNegI,
+  // Binary operations.
   kMulF,
   kMulI,
   kDivF,
@@ -41,6 +46,7 @@ enum class Kind {
   kSubI,
   kAndI,
   kOrI,
+  kXorI,
 };
 
 /// Children subexpressions of tensor operations.
@@ -105,8 +111,7 @@ public:
         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
-  unsigned addExp(Kind k, unsigned e0 = -1u, unsigned e1 = -1u,
-                  Value v = Value());
+  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
 
   /// Adds an iteration lattice point. Returns its index.
@@ -129,11 +134,10 @@ public:
   /// Returns the index of the new set.
   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
 
-  /// Maps a zero operand over a lattice set, i.e. each lattice point on an
-  /// expression E is simply copied over, but with 0 OP E as new expression.
-  /// This is useful to deal with disjunctive, but non-commutative operators.
-  /// Returns the index of the new set.
-  unsigned mapZero(Kind kind, unsigned s0);
+  /// Maps the unary operator over the lattice set of the operand, i.e. each
+  /// lattice point on an expression E is simply copied over, but with OP E
+  /// as new expression. Returns the index of the new set.
+  unsigned mapSet(Kind kind, unsigned s0);
 
   /// Optimizes the iteration lattice points in the given set. This
   /// method should be called right before code generation to avoid
index 9d5447e..a852a3d 100644 (file)
@@ -609,18 +609,22 @@ static void genReductionEnd(Merger &merger, CodeGen &codegen,
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
   Location loc = op.getLoc();
+  if (exp == -1u)
+    return Value();
   if (merger.exp(exp).kind == Kind::kTensor)
     return genTensorLoad(merger, codegen, rewriter, op, exp);
   if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
-  if (merger.exp(exp).kind == Kind::kZero) {
-    Type tp = op.getOutputTensorTypes()[0].getElementType();
-    merger.exp(exp).val =
-        rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
-    return genInvariantValue(merger, codegen, rewriter, exp);
-  }
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
+  if (merger.exp(exp).kind == Kind::kNegI) {
+    // TODO: no negi in std, need to make zero explicit.
+    Type tp = op.getOutputTensorTypes()[0].getElementType();
+    v1 = v0;
+    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+    if (codegen.curVecLength > 1)
+      v0 = genVectorInvariantValue(codegen, rewriter, v0);
+  }
   return merger.buildExp(rewriter, loc, exp, v0, v1);
 }
 
@@ -628,6 +632,8 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 static void genInvariants(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp, unsigned ldx, bool hoist) {
+  if (exp == -1u)
+    return;
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
@@ -649,8 +655,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
       merger.exp(exp).val =
           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
-  } else if (merger.exp(exp).kind != Kind::kInvariant &&
-             merger.exp(exp).kind != Kind::kZero) {
+  } else if (merger.exp(exp).kind != Kind::kInvariant) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
index 74a3a24..b254f03 100644 (file)
@@ -28,8 +28,14 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
   case Kind::kInvariant:
     assert(x == -1u && y == -1u && v);
     break;
-  case Kind::kZero:
-    assert(x == -1u && y == -1u && !v);
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    assert(x != -1u && y == -1u && !v);
+    children.e0 = x;
+    children.e1 = y;
     break;
   default:
     assert(x != -1u && y != -1u && !v);
@@ -89,22 +95,25 @@ unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
 
 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
   unsigned s = takeConj(kind, s0, s1);
-  // Followed by all in s0 and s1.
+  // Followed by all in s0.
   for (unsigned p : latSets[s0])
     latSets[s].push_back(p);
-  if (Kind::kSubF <= kind && kind <= Kind::kSubI)
-    s1 = mapZero(kind, s1);
+  // Map binary 0-y to unary -y.
+  if (kind == Kind::kSubF)
+    s1 = mapSet(Kind::kNegF, s1);
+  else if (kind == Kind::kSubI)
+    s1 = mapSet(Kind::kNegI, s1);
+  // Followed by all in s1.
   for (unsigned p : latSets[s1])
     latSets[s].push_back(p);
   return s;
 }
 
-unsigned Merger::mapZero(Kind kind, unsigned s0) {
-  assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
+unsigned Merger::mapSet(Kind kind, unsigned s0) {
+  assert(Kind::kAbsF <= kind && kind <= Kind::kNegI);
   unsigned s = addSet();
-  unsigned z = addExp(Kind::kZero);
   for (unsigned p : latSets[s0]) {
-    unsigned e = addExp(kind, z, latPoints[p].exp);
+    unsigned e = addExp(kind, latPoints[p].exp);
     latPoints.push_back(LatPoint(latPoints[p].bits, e));
     latSets[s].push_back(latPoints.size() - 1);
   }
@@ -194,6 +203,12 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
     return tensorExps[e].tensor == t;
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    return isConjunction(t, tensorExps[e].children.e0);
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
@@ -213,30 +228,9 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
 // Print methods (for debugging).
 //
 
-static char kindToOpSymbol(Kind kind) {
-  switch (kind) {
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return '*';
-  case Kind::kDivF:
-  case Kind::kDivS:
-  case Kind::kDivU:
-    return '/';
-  case Kind::kAddF:
-  case Kind::kAddI:
-    return '+';
-  case Kind::kSubF:
-  case Kind::kSubI:
-    return '-';
-  case Kind::kAndI:
-    return '&';
-  case Kind::kOrI:
-    return '|';
-  default:
-    break;
-  }
-  llvm_unreachable("unexpected kind");
-}
+static const char *kOpSymbols[] = {"",  "",  "abs", "ceil", "floor", "-",
+                                   "-", "*", "*",   "/",    "/",     "+",
+                                   "+", "-", "-",   "&",    "|",     "^"};
 
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
@@ -250,13 +244,18 @@ void Merger::dumpExp(unsigned e) const {
   case Kind::kInvariant:
     llvm::dbgs() << "invariant";
     break;
-  case Kind::kZero:
-    llvm::dbgs() << "zero";
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " ";
+    dumpExp(tensorExps[e].children.e0);
     break;
   default:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
-    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
+    llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " ";
     dumpExp(tensorExps[e].children.e1);
     llvm::dbgs() << ")";
   }
@@ -315,8 +314,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case Kind::kTensor:
-  case Kind::kInvariant:
-  case Kind::kZero: {
+  case Kind::kInvariant: {
     // Either the index is really used in the tensor expression, or it is
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
@@ -325,6 +323,18 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
+    // lattice set of the operand through the operator into a new set.
+    //
+    //  -y|!y | y |
+    //  --+---+---+
+    //    | 0 |-y |
+    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
@@ -357,16 +367,12 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
-  case Kind::kSubF:
-  case Kind::kSubI:
-    // Special case: 0-y is -y.
-    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
-      return mapZero(kind, // maps to 0-y with just y's lattices
-                     buildLattices(tensorExps[e].children.e1, i));
-    LLVM_FALLTHROUGH;
   case Kind::kAddF:
   case Kind::kAddI:
+  case Kind::kSubF:
+  case Kind::kSubI:
   case Kind::kOrI:
+  case Kind::kXorI:
     // An additive operation needs to be performed
     // for the disjunction of sparse iteration spaces.
     //
@@ -420,10 +426,15 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
     if (x.hasValue()) {
-      unsigned e0 = addExp(Kind::kZero);
-      unsigned e1 = x.getValue();
+      unsigned e = x.getValue();
+      if (isa<AbsFOp>(def))
+        return addExp(Kind::kAbsF, e);
+      if (isa<CeilFOp>(def))
+        return addExp(Kind::kCeilF, e);
+      if (isa<FloorFOp>(def))
+        return addExp(Kind::kFloorF, e);
       if (isa<NegFOp>(def))
-        return addExp(Kind::kSubF, e0, e1);
+        return addExp(Kind::kNegF, e);
       // TODO: no negi in std?
     }
   }
@@ -457,6 +468,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(Kind::kAndI, e0, e1);
       if (isa<OrOp>(def))
         return addExp(Kind::kOrI, e0, e1);
+      if (isa<XOrOp>(def))
+        return addExp(Kind::kXorI, e0, e1);
     }
   }
   // Cannot build.
@@ -468,8 +481,18 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
   case Kind::kInvariant:
-  case Kind::kZero:
     llvm_unreachable("unexpected non-op");
+  case kAbsF:
+    return rewriter.create<AbsFOp>(loc, v0);
+  case kCeilF:
+    return rewriter.create<CeilFOp>(loc, v0);
+  case kFloorF:
+    return rewriter.create<FloorFOp>(loc, v0);
+  case kNegF:
+    return rewriter.create<NegFOp>(loc, v0);
+  case kNegI:
+    assert(v1); // no negi in std
+    return rewriter.create<SubIOp>(loc, v0, v1);
   case Kind::kMulF:
     return rewriter.create<MulFOp>(loc, v0, v1);
   case Kind::kMulI:
@@ -492,6 +515,8 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
     return rewriter.create<AndOp>(loc, v0, v1);
   case Kind::kOrI:
     return rewriter.create<OrOp>(loc, v0, v1);
+  case Kind::kXorI:
+    return rewriter.create<XOrOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");
 }
index ff43aa7..c7c6507 100644 (file)
   doc = "x(i) = a(i) OP c"
 }
 
+// CHECK-LABEL:   func @abs(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = absf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
+func @abs(%arga: tensor<32xf64, #SV>,
+          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = absf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL:   func @ceil(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = ceilf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
+// CHECK:         }
+func @ceil(%arga: tensor<32xf64, #SV>,
+           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = ceilf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL:   func @floor(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = floorf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
+// CHECK:         }
+func @floor(%arga: tensor<32xf64, #SV>,
+            %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = floorf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
 // CHECK-LABEL:   func @neg(
 // CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
 // CHECK:           %[[VAL_2:.*]] = constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = constant 0.000000e+00 : f64
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
-// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
-// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf64>
-// CHECK:             %[[VAL_14:.*]] = subf %[[VAL_4]], %[[VAL_13]] : f64
-// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = negf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK:           }
-// CHECK:           %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf64>
-// CHECK:           return %[[VAL_15]] : tensor<32xf64>
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
 // CHECK:         }
 func @neg(%arga: tensor<32xf64, #SV>,
           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
@@ -132,47 +226,46 @@ func @add(%arga: tensor<32xf64, #SV>,
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant true
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = constant 0.000000e+00 : f64
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
-// CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
-// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_18:.*]] = cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
-// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
-// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK:             %[[VAL_22:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
-// CHECK:             scf.if %[[VAL_22]] {
-// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xf64>
-// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
-// CHECK:               %[[VAL_25:.*]] = subf %[[VAL_23]], %[[VAL_24]] : f64
-// CHECK:               memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             scf.if %[[VAL_21]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:               %[[VAL_24:.*]] = subf %[[VAL_22]], %[[VAL_23]] : f64
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK:             } else {
 // CHECK:               scf.if %[[VAL_5]] {
-// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
-// CHECK:                 %[[VAL_27:.*]] = subf %[[VAL_7]], %[[VAL_26]] : f64
-// CHECK:                 memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:                 %[[VAL_26:.*]] = negf %[[VAL_25]] : f64
+// CHECK:                 memref.store %[[VAL_26]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK:               } else {
 // CHECK:               }
 // CHECK:             }
-// CHECK:             %[[VAL_28:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
-// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
-// CHECK:             %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index
-// CHECK:             %[[VAL_31:.*]] = addi %[[VAL_20]], %[[VAL_6]] : index
-// CHECK:             scf.yield %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK:             %[[VAL_27:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_28:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_29:.*]] = select %[[VAL_27]], %[[VAL_28]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_30:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
 // CHECK:           }
-// CHECK:           scf.for %[[VAL_32:.*]] = %[[VAL_33:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<32xf64>
-// CHECK:             %[[VAL_35:.*]] = subf %[[VAL_7]], %[[VAL_34]] : f64
-// CHECK:             memref.store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<32xf64>
+// CHECK:           scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_33:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<32xf64>
+// CHECK:             %[[VAL_34:.*]] = negf %[[VAL_33]] : f64
+// CHECK:             memref.store %[[VAL_34]], %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<32xf64>
 // CHECK:           }
-// CHECK:           %[[VAL_36:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64>
-// CHECK:           return %[[VAL_36]] : tensor<32xf64>
+// CHECK:           %[[VAL_35:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK:           return %[[VAL_35]] : tensor<32xf64>
 // CHECK:         }
 func @sub(%arga: tensor<32xf64, #SV>,
           %argb: tensor<32xf64>,
index 84d2557..11f82f9 100644 (file)
@@ -345,3 +345,62 @@ func @or(%arga: tensor<32xi64, #SV>,
   return %0 : tensor<32xi64>
 }
 
+// CHECK-LABEL:   func @xor(
+// CHECK-SAME:             %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:             %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:             %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant true
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             scf.if %[[VAL_21]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               %[[VAL_24:.*]] = xor %[[VAL_22]], %[[VAL_23]] : i64
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_5]] {
+// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:                 memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:             memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi64>
+// CHECK:           return %[[VAL_33]] : tensor<32xi64>
+// CHECK:         }
+func @xor(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = xor %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
index b544ce1..1cb4d08 100644 (file)
@@ -145,14 +145,24 @@ protected:
     switch (tensorExp.kind) {
     case Kind::kTensor:
       return tensorExp.tensor == pattern->tensorNum;
-    case Kind::kZero:
-      return true;
+    case Kind::kAbsF:
+    case Kind::kCeilF:
+    case Kind::kFloorF:
+    case Kind::kNegF:
+    case Kind::kNegI:
+      return compareExpression(tensorExp.children.e0, pattern->e0);
     case Kind::kMulF:
     case Kind::kMulI:
+    case Kind::kDivF:
+    case Kind::kDivS:
+    case Kind::kDivU:
     case Kind::kAddF:
     case Kind::kAddI:
     case Kind::kSubF:
     case Kind::kSubI:
+    case Kind::kAndI:
+    case Kind::kOrI:
+    case Kind::kXorI:
       return compareExpression(tensorExp.children.e0, pattern->e0) &&
              compareExpression(tensorExp.children.e1, pattern->e1);
     default: