Add support for casting elements in vectors for certain Std dialect type conversion...
authorLubomir Litchev <Lubomir.Litchev@intel.com>
Wed, 9 Sep 2020 19:34:08 +0000 (12:34 -0700)
committerLubomir Litchev <Lubomir.Litchev@intel.com>
Mon, 14 Sep 2020 14:45:46 +0000 (07:45 -0700)
Added support to the Std dialect cast operations to do casts in vector types when feasible.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87410

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index afdc3ed..4d0cf76 100644 (file)
@@ -2443,10 +2443,10 @@ def SignExtendIOp : Std_Op<"sexti",
 def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from integer type to floating-point";
   let description = [{
-    Cast from a value interpreted as signed integer to the corresponding
-    floating-point value. If the value cannot be exactly represented, it is
-    rounded using the default rounding mode. Only scalars are currently
-    supported.
+    Cast from a value interpreted as signed or vector of signed integers to the
+    corresponding floating-point scalar or vector value. If the value cannot be
+    exactly represented, it is rounded using the default rounding mode. Scalars
+    and vector types are currently supported.
   }];
 
   let extraClassDeclaration = [{
@@ -3124,10 +3124,10 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
 def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from unsigned integer type to floating-point";
   let description = [{
-    Cast from a value interpreted as unsigned integer to the corresponding
-    floating-point value. If the value cannot be exactly represented, it is
-    rounded using the default rounding mode. Only scalars are currently
-    supported.
+    Cast from a value interpreted as unsigned integer or vector of unsigned
+    integers to the corresponding scalar or vector floating-point value. If the
+    value cannot be exactly represented, it is rounded using the default
+    rounding mode. Scalars and vector types are currently supported.
   }];
 
   let extraClassDeclaration = [{
index cf085a6..c77bc12 100644 (file)
@@ -218,6 +218,26 @@ static LogicalResult foldMemRefCast(Operation *op) {
 }
 
 //===----------------------------------------------------------------------===//
+// Common cast compatibility check for vector types.
+//===----------------------------------------------------------------------===//
+
+/// This method checks for cast compatibility of vector types.
+/// If 'a' and 'b' are vector types, and they are cast compatible,
+/// it calls the 'areElementsCastCompatible' function to check for
+/// element cast compatibility.
+/// Returns 'true' if the vector types are cast compatible,  and 'false'
+/// otherwise.
+static bool areVectorCastSimpleCompatible(
+    Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
+  if (auto va = a.dyn_cast<VectorType>())
+    if (auto vb = b.dyn_cast<VectorType>())
+      return va.getShape().equals(vb.getShape()) &&
+             areElementsCastCompatible(va.getElementType(),
+                                       vb.getElementType());
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
 
@@ -1816,11 +1836,7 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() < fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1828,7 +1844,9 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
 //===----------------------------------------------------------------------===//
 
 bool FPToSIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1836,7 +1854,9 @@ bool FPToSIOp::areCastCompatible(Type a, Type b) {
 //===----------------------------------------------------------------------===//
 
 bool FPToUIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1847,11 +1867,7 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() > fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2291,7 +2307,9 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
 
 // sitofp is applicable from integer types to float types.
 bool SIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2371,7 +2389,9 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
 
 // uitofp is applicable from integer types to float types.
 bool UIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
index 62be478..bb0363b 100644 (file)
@@ -594,6 +594,24 @@ func @sitofp(%arg0 : i32, %arg1 : i64) {
   return
 }
 
+// Checking conversion of integer vectors to floating point vector types.
+// CHECK-LABEL: @sitofp_vector
+func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = sitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = sitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = sitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = sitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = sitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = sitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
+
 // Checking conversion of unsigned integer types to floating point.
 // CHECK-LABEL: @uitofp
 func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -646,6 +664,24 @@ func @fptosi(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// Checking conversion of floating point vectors to integer vector types.
+// CHECK-LABEL: @fptosi_vector
+func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptosi %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptosi %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptosi %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptosi %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptosi %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptosi %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptoui
 func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -660,6 +696,41 @@ func @fptoui(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// Checking conversion of floating point vectors to integer vector types.
+// CHECK-LABEL: @fptoui_vector
+func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptoui %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptoui %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptoui %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptoui %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptoui %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptoui %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
+// Checking conversion of integer vectors to floating point vector types.
+// CHECK-LABEL: @uitofp_vector
+func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = uitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = uitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = uitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = uitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = uitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = uitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
 
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc