[mlir][sparse] improve push_back type checking, printing, parsing
authorAart Bik <ajcbik@google.com>
Tue, 18 Oct 2022 06:08:09 +0000 (23:08 -0700)
committerAart Bik <ajcbik@google.com>
Tue, 18 Oct 2022 16:55:25 +0000 (09:55 -0700)
Rationale:
Enforces type consistency on parsed and generated IR.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D136132

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir

index 14e5af0..16a0565 100644 (file)
@@ -238,7 +238,11 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
   let hasVerifier = 1;
 }
 
-def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
+def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
+    [TypesMatchWith<"value type matches element type of inBuffer",
+                    "inBuffer", "value",
+                    "$_self.cast<ShapedType>().getElementType()">,
+     AllTypesMatch<["inBuffer", "outBuffer"]>]>,
     Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
                StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
                AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
@@ -263,19 +267,18 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
     Example:
 
     ```mlir
-    %r = sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index}
-      : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
+    %r = sparse_tensor.push_back %bufferSizes, %buffer, %val
+      {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
     ```
 
     ```mlir
     %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
-      {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
+      {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
     ```
   }];
   let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
                        " `,` $value attr-dict `:` type($bufferSizes) `,`"
-                       " type($inBuffer) `,` type($value) `to`"
-                       " type($outBuffer)";
+                       " type($inBuffer) `,` type($value)";
 }
 
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
index dd81c52..31c6ad5 100644 (file)
@@ -22,7 +22,7 @@
 //       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
 //       CHECK: return %[[M]] : memref<?xf64>
 func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
   return %0 : memref<?xf64>
 }
 
@@ -40,7 +40,7 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
 //       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
 //       CHECK: return %[[B]] : memref<?xf64>
 func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
-  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
   return %0 : memref<?xf64>
 }
 
index 7ac5179..21aa160 100644 (file)
@@ -124,6 +124,14 @@ func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: ind
 
 // -----
 
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f32) -> memref<?xf64> {
+  // expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f32
+  return %0 : memref<?xf64>
+}
+
+// -----
+
 func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
   // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %values, %filled, %added, %count = sparse_tensor.expand %arg0
index ca3f884..8c8ab6b 100644 (file)
@@ -136,10 +136,10 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %a
 //  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
 //  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
 //  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-//       CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+//       CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
 //       CHECK: return %[[D]]
 func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
   return %0 : memref<?xf64>
 }
 
@@ -149,10 +149,10 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
 //  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
 //  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
 //  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-//       CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+//       CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
 //       CHECK: return %[[D]]
 func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
-  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
   return %0 : memref<?xf64>
 }
 
index ff57bfe..90d1b37 100644 (file)
@@ -16,8 +16,8 @@ module {
     %buffer = memref.alloc(%c1) : memref<?xf32>
 
     memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
-    %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
-    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+    %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
+    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
 
     // CHECK: ( 2 )
     %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>