[mlir][sparse] Fix a problem in handling data type conversion.
authorbixia1 <bixia@google.com>
Thu, 30 Mar 2023 20:58:45 +0000 (13:58 -0700)
committerbixia1 <bixia@google.com>
Thu, 30 Mar 2023 21:54:53 +0000 (14:54 -0700)
Previously, the genCast function generates arith.trunci for converting f32 to
i32. Fix the function to use mlir::convertScalarToDtype to correctly handle
conversion cases beyond index casting.

Add a test case for codegen the sparse_tensor.convert op.

Reviewed By: aartbik, Peiming, wrengr

Differential Revision: https://reviews.llvm.org/D147272

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir

index bdd6020..957d41b 100644 (file)
@@ -208,28 +208,9 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
     return builder.create<arith::IndexCastOp>(loc, dstTp, value);
 
-  const bool ext =
-      srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth();
-
-  // float => float.
-  if (srcTp.isa<FloatType>() && dstTp.isa<FloatType>()) {
-    if (ext)
-      return builder.create<arith::ExtFOp>(loc, dstTp, value);
-    return builder.create<arith::TruncFOp>(loc, dstTp, value);
-  }
-
-  // int => int
-  const auto srcIntTp = srcTp.dyn_cast<IntegerType>();
-  if (srcIntTp && dstTp.isa<IntegerType>()) {
-    if (!ext)
-      return builder.create<arith::TruncIOp>(loc, dstTp, value);
-    if (srcIntTp.isUnsigned())
-      return builder.create<arith::ExtUIOp>(loc, dstTp, value);
-    if (srcIntTp.isSigned())
-      return builder.create<arith::ExtSIOp>(loc, dstTp, value);
-  }
-
-  llvm_unreachable("unhandled type casting");
+  const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
+  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
index 87db474..4a54212 100644 (file)
@@ -663,6 +663,22 @@ func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?x
   return %0 : tensor<?xf32, #SparseVector>
 }
 
+// CHECK-LABEL: func.func @sparse_convert_element_type(
+//  CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*]]: memref<?xf32>,
+//  CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+//       CHECK: scf.for
+//       CHECK:   %[[FValue:.*]] = memref.load
+//       CHECK:   %[[IValue:.*]] = arith.fptosi %[[FValue]]
+//       CHECK:   memref.store %[[IValue]]
+//       CHECK: return  %{{.*}}, %{{.*}}, %{{.*}}, %[[A4]] :
+//  CHECK-SAME:   memref<?xi32>, memref<?xi64>, memref<?xi32>, !sparse_tensor.storage_specifier
+func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xi32, #SparseVector>
+  return %0 : tensor<?xi32, #SparseVector>
+}
+
 // CHECK-LABEL: func.func @sparse_new_coo(
 // CHECK-SAME:  %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "compressed", "singleton" ] }>>) {
 //   CHECK-DAG: %[[A1:.*]] = arith.constant false