#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/FormatVariadic.h"
#include <numeric>
using namespace mlir;
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
}
+/// Returns the number of bits for the given scalar/vector type.
+static int getNumBits(Type type) {
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return vectorType.cast<ShapedType>().getSizeInBits();
+ return type.getIntOrFloatBitWidth();
+}
+
namespace {
struct VectorBitcastConvert final
if (!dstType)
return failure();
- if (dstType == adaptor.getSource().getType())
+ if (dstType == adaptor.getSource().getType()) {
rewriter.replaceOp(bitcastOp, adaptor.getSource());
- else
- rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
- adaptor.getSource());
+ return success();
+ }
+
+ // Check that the source and destination type have the same bitwidth.
+ // Depending on the target environment, we may need to emulate certain
+ // types, which can cause issue with bitcast.
+ Type srcType = adaptor.getSource().getType();
+ if (getNumBits(dstType) != getNumBits(srcType)) {
+ return rewriter.notifyMatchFailure(
+ bitcastOp,
+ llvm::formatv("different source ({0}) and target ({1}) bitwidth",
+ srcType, dstType));
+ }
+ rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
+ adaptor.getSource());
return success();
}
};
// -----
+// Check that without the proper capability we fail the pattern application
+// to avoid generating invalid ops.
+
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>> } {
+
+// CHECK-LABEL: @bitcast
+func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) {
+ // CHECK-COUNT-2: vector.bitcast
+ %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
+ %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
+ return %0, %1: vector<4xf16>, vector<1xf32>
+}
+
+} // end module
+
+// -----
+
module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>> } {
// CHECK-LABEL: @cl_fma