From 6071f6fd67aa7ee7b9a29788118dc46be7f6cdcf Mon Sep 17 00:00:00 2001
From: bixia1 <bixia@google.com>
Date: Thu, 30 Mar 2023 13:58:45 -0700
Subject: [PATCH] [mlir][sparse] Fix a problem in handling data type conversion.

Previously, the genCast function generates arith.trunci for converting f32 to
i32. Fix the function to use mlir::convertScalarToDtype to correctly handle
conversion cases beyond index casting.

Add a test case for codegen the sparse_tensor.convert op.

Reviewed By: aartbik, Peiming, wrengr

Differential Revision: https://reviews.llvm.org/D147272
---
 .../SparseTensor/Transforms/CodegenUtils.cpp  | 25 +++-------------------
 mlir/test/Dialect/SparseTensor/codegen.mlir   | 16 ++++++++++++++
 2 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index bdd6020..957d41b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -208,28 +208,9 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
     return builder.create<arith::IndexCastOp>(loc, dstTp, value);
 
-  const bool ext =
-      srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth();
-
-  // float => float.
-  if (srcTp.isa<FloatType>() && dstTp.isa<FloatType>()) {
-    if (ext)
-      return builder.create<arith::ExtFOp>(loc, dstTp, value);
-    return builder.create<arith::TruncFOp>(loc, dstTp, value);
-  }
-
-  // int => int
-  const auto srcIntTp = srcTp.dyn_cast<IntegerType>();
-  if (srcIntTp && dstTp.isa<IntegerType>()) {
-    if (!ext)
-      return builder.create<arith::TruncIOp>(loc, dstTp, value);
-    if (srcIntTp.isUnsigned())
-      return builder.create<arith::ExtUIOp>(loc, dstTp, value);
-    if (srcIntTp.isSigned())
-      return builder.create<arith::ExtSIOp>(loc, dstTp, value);
-  }
-
-  llvm_unreachable("unhandled type casting");
+  const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
+  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 87db474..4a54212 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -663,6 +663,22 @@ func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   return %0 : tensor<?xf32, #SparseVector>
 }
 
+// CHECK-LABEL: func.func @sparse_convert_element_type(
+// CHECK-SAME: %[[A1:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: scf.for
+// CHECK: %[[FValue:.*]] = memref.load
+// CHECK: %[[IValue:.*]] = arith.fptosi %[[FValue]]
+// CHECK: memref.store %[[IValue]]
+// CHECK: return %{{.*}}, %{{.*}}, %{{.*}}, %[[A4]] :
+// CHECK-SAME: memref<?xindex>, memref<?xindex>, memref<?xi32>, !sparse_tensor.storage_specifier
+func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xi32, #SparseVector>
+  return %0 : tensor<?xi32, #SparseVector>
+}
+
 // CHECK-LABEL: func.func @sparse_new_coo(
 // CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "compressed", "singleton" ] }>>) {
 // CHECK-DAG: %[[A1:.*]] = arith.constant false
-- 
2.7.4