[mlir][sparse] refine insertion code
authorAart Bik <ajcbik@google.com>
Tue, 18 Oct 2022 17:35:00 +0000 (10:35 -0700)
committerAart Bik <ajcbik@google.com>
Tue, 18 Oct 2022 21:16:38 +0000 (14:16 -0700)
builds SSA cycle for compress insertion loop
adds casting on index mismatch during push_back

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136186

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir

index 5e5815b..77bfef7 100644 (file)
@@ -265,11 +265,14 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
 }
 
 /// Creates a straightforward counting for-loop.
-static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
+                            SmallVectorImpl<Value> &fields) {
   Type indexType = builder.getIndexType();
   Value zero = constantZero(builder, loc, indexType);
   Value one = constantOne(builder, loc, indexType);
-  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one, fields);
+  for (unsigned i = 0, e = fields.size(); i < e; i++)
+    fields[i] = forOp.getRegionIterArg(i);
   builder.setInsertionPointToStart(forOp.getBody());
   return forOp;
 }
@@ -280,6 +283,9 @@ static void createPushback(OpBuilder &builder, Location loc,
                            SmallVectorImpl<Value> &fields, unsigned field,
                            Value value) {
   assert(field < fields.size());
+  Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+  if (value.getType() != etp)
+    value = builder.create<arith::IndexCastOp>(loc, etp, value);
   fields[field] =
       builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
                                  fields[field], value, APInt(64, field));
@@ -298,11 +304,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
       !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
-  // push_back memSizes pointers-0 0
   // push_back memSizes indices-0 index
   // push_back memSizes values    value
-  Value zero = constantIndex(builder, loc, 0);
-  createPushback(builder, loc, fields, 2, zero);
   createPushback(builder, loc, fields, 3, indices[0]);
   createPushback(builder, loc, fields, 4, value);
 }
@@ -316,9 +319,12 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
       !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
+  // push_back memSizes pointers-0 0
   // push_back memSizes pointers-0 memSizes[2]
+  Value zero = constantIndex(builder, loc, 0);
   Value two = constantIndex(builder, loc, 2);
   Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+  createPushback(builder, loc, fields, 2, zero);
   createPushback(builder, loc, fields, 2, size);
 }
 
@@ -460,6 +466,7 @@ public:
     Location loc = op.getLoc();
     SmallVector<Value, 8> fields;
     createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
   }
@@ -504,6 +511,7 @@ public:
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), srcType, fields);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
     return success();
   }
@@ -591,23 +599,26 @@ public:
     // sparsity of the expanded access pattern.
     //
     // Generate
-    //    for (i = 0; i < count; i++) {
+    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
     //      index = added[i];
     //      value = values[index];
     //      insert({prev_indices, index}, value);
+    //      new_memrefs = insert(in_memrefs, {prev_indices, index}, value);
     //      values[index] = 0;
     //      filled[index] = false;
+    //      yield new_memrefs
     //    }
-    Value i = createFor(rewriter, loc, count).getInductionVar();
+    scf::ForOp loop = createFor(rewriter, loc, count, fields);
+    Value i = loop.getInductionVar();
     Value index = rewriter.create<memref::LoadOp>(loc, added, i);
     Value value = rewriter.create<memref::LoadOp>(loc, values, index);
     indices.push_back(index);
-    // TODO: generate yield cycle
     genInsert(rewriter, loc, dstType, fields, indices, value);
     rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
                                      values, index);
     rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
                                      filled, index);
+    rewriter.create<scf::YieldOp>(loc, fields);
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = op;
     for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -620,7 +631,9 @@ public:
     rewriter.create<memref::DeallocOp>(loc, values);
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+    // Replace operation with resulting memrefs.
+    rewriter.replaceOp(op,
+                       genTuple(rewriter, loc, dstType, loop->getResults()));
     return success();
   }
 };
@@ -641,6 +654,7 @@ public:
     // Generate insertion.
     Value value = adaptor.getValue();
     genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();
   }
index ecdf0e4..b469e66 100644 (file)
@@ -354,6 +354,49 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
   return %added : memref<?xindex>
 }
 
+// CHECK-LABEL: func @sparse_compression_1d(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+//  CHECK-SAME: %[[A8:.*8]]: index)
+//   CHECK-DAG:  %[[B0:.*]] = arith.constant false
+//   CHECK-DAG:  %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+//   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+//       CHECK:  sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
+//       CHECK:  %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+//       CHECK:    %[[T1:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+//       CHECK:    %[[T2:.*]] = memref.load %[[A5]][%[[T1]]] : memref<?xf64>
+//       CHECK:    %[[T3:.*]] = sparse_tensor.push_back %[[A1]], %[[P0]], %[[T1]] {idx = 3 : index} : memref<3xindex>, memref<?xindex>, index
+//       CHECK:    %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[T2]] {idx = 4 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK:    memref.store %[[F0]], %arg5[%[[T1]]] : memref<?xf64>
+//       CHECK:    memref.store %[[B0]], %arg6[%[[T1]]] : memref<?xi1>
+//       CHECK:    scf.yield %[[T3]], %[[T4]] : memref<?xindex>, memref<?xf64>
+//       CHECK:  }
+//       CHECK:  memref.dealloc %[[A5]] : memref<?xf64>
+//       CHECK:  memref.dealloc %[[A6]] : memref<?xi1>
+//       CHECK:  memref.dealloc %[[A7]] : memref<?xindex>
+//       CHECK:  %[[LL:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK:    %[[P1:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] {idx = 2 : index} : memref<3xindex>, memref<?xindex>, index
+//       CHECK:    %[[P2:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[LL]] {idx = 2 : index} : memref<3xindex>, memref<?xindex>, index
+//       CHECK:  return %[[A0]], %[[A1]], %[[P2]], %[[R]]#0, %[[R]]#1 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
+                                 %values: memref<?xf64>,
+                                 %filled: memref<?xi1>,
+                                 %added: memref<?xindex>,
+                                 %count: index) -> tensor<100xf64, #SV> {
+  %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[]
+    : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<100xf64, #SV>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<100xf64, #SV>
+  return %1 : tensor<100xf64, #SV>
+}
+
 // CHECK-LABEL: func @sparse_compression(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -372,7 +415,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //       CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
 //  CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
 //  CHECK-NEXT:   %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-//        TODO:   insert
+//        TODO:   2D-insert
 //   CHECK-DAG:   memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 //   CHECK-DAG:   memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
 //  CHECK-NEXT: }
@@ -388,7 +431,8 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
                               %i: index) -> tensor<8x8xf64, #CSR> {
   %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
-  return %0 : tensor<8x8xf64, #CSR>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<8x8xf64, #CSR>
+  return %1 : tensor<8x8xf64, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_compression_unordered(
@@ -409,7 +453,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 //   CHECK-NOT: sparse_tensor.sort
 //  CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
 //  CHECK-NEXT:   %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-//        TODO:   insert
+//        TODO:   2D-insert
 //   CHECK-DAG:   memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 //   CHECK-DAG:   memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
 //  CHECK-NEXT: }
@@ -425,7 +469,8 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
                                         %i: index) -> tensor<8x8xf64, #UCSR> {
   %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
-  return %0 : tensor<8x8xf64, #UCSR>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<8x8xf64, #UCSR>
+  return %1 : tensor<8x8xf64, #UCSR>
 }
 
 // CHECK-LABEL: func @sparse_insert(
@@ -438,10 +483,10 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
 //  CHECK-SAME: %[[A6:.*6]]: f64)
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
 //       CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]]
 //       CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
 //       CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
 //       CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]]
 //       CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
@@ -449,3 +494,27 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
   return %1 : tensor<128xf64, #SV>
 }
+
+// CHECK-LABEL: func @sparse_insert_typed(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: f64)
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[S1:.*]] = arith.index_cast %[[A5]] : index to i64
+//       CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[S1]]
+//       CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
+//       CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
+//       CHECK: %[[S2:.*]] = arith.index_cast %[[T3]] : index to i32
+//       CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[S2]]
+//       CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
+  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
+  return %1 : tensor<128xf64, #SparseVector>
+}