[mlir][spirv] Fail vector.bitcast conversion with different bitwidth
authorLei Zhang <antiagainst@gmail.com>
Thu, 29 Dec 2022 23:27:29 +0000 (15:27 -0800)
committerLei Zhang <antiagainst@gmail.com>
Thu, 29 Dec 2022 23:43:55 +0000 (15:43 -0800)
Depending on the target environment, we may need to emulate certain
types, which can cause issue with bitcast.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D140437

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

index 11608db..da50588 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <numeric>
 
 using namespace mlir;
@@ -33,6 +34,13 @@ static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
 
+/// Returns the number of bits for the given scalar/vector type.
+static int getNumBits(Type type) {
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return vectorType.cast<ShapedType>().getSizeInBits();
+  return type.getIntOrFloatBitWidth();
+}
+
 namespace {
 
 struct VectorBitcastConvert final
@@ -46,12 +54,24 @@ struct VectorBitcastConvert final
     if (!dstType)
       return failure();
 
-    if (dstType == adaptor.getSource().getType())
+    if (dstType == adaptor.getSource().getType()) {
       rewriter.replaceOp(bitcastOp, adaptor.getSource());
-    else
-      rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
-                                                    adaptor.getSource());
+      return success();
+    }
+
+    // Check that the source and destination type have the same bitwidth.
+    // Depending on the target environment, we may need to emulate certain
+    // types, which can cause issue with bitcast.
+    Type srcType = adaptor.getSource().getType();
+    if (getNumBits(dstType) != getNumBits(srcType)) {
+      return rewriter.notifyMatchFailure(
+          bitcastOp,
+          llvm::formatv("different source ({0}) and target ({1}) bitwidth",
+                        srcType, dstType));
+    }
 
+    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
+                                                  adaptor.getSource());
     return success();
   }
 };
index 4bb1835..26a2ab6 100644 (file)
@@ -16,6 +16,23 @@ func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16
 
 // -----
 
+// Check that without the proper capability we fail the pattern application
+// to avoid generating invalid ops.
+
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>> } {
+
+// CHECK-LABEL: @bitcast
+func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) {
+  // CHECK-COUNT-2: vector.bitcast
+  %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
+  %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
+  return %0, %1: vector<4xf16>, vector<1xf32>
+}
+
+} // end module
+
+// -----
+
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>> } {
 
 // CHECK-LABEL: @cl_fma