[MLIR][Arith][NFC] Use the interface of 'getElementTypeOrSelf' to get the resType
authorliqinweng <Liqin.Weng@streamcomputing.com>
Fri, 30 Dec 2022 03:54:06 +0000 (11:54 +0800)
committerliqinweng <Liqin.Weng@streamcomputing.com>
Fri, 30 Dec 2022 03:54:06 +0000 (11:54 +0800)
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D140608

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

index f6446ea..e61169c 100644 (file)
@@ -1176,12 +1176,9 @@ OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
+
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.zext(bitWidth);
@@ -1205,12 +1202,9 @@ OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
+
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.sext(bitWidth);
@@ -1259,13 +1253,8 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
     return getResult();
   }
 
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
-
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.trunc(bitWidth);
@@ -1361,12 +1350,7 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
@@ -1387,12 +1371,7 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
@@ -1412,17 +1391,12 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
-        IntegerType intTy = resEleType.cast<IntegerType>();
+      operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
-        APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
+        APSInt api(bitWidth, /*isUnsigned=*/true);
         castStatus = APFloat::opInvalidOp !=
                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
         return api;
@@ -1438,17 +1412,12 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
-        IntegerType intTy = resEleType.cast<IntegerType>();
+      operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
-        APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
+        APSInt api(bitWidth, /*isUnsigned=*/false);
         castStatus = APFloat::opInvalidOp !=
                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
         return api;