[mlir][spirv] Simplify scalar type size calculation.
authorDenis Khalikov <khalikov.denis@huawei.com>
Tue, 21 Jan 2020 17:12:17 +0000 (12:12 -0500)
committerLei Zhang <antiagainst@google.com>
Tue, 21 Jan 2020 17:15:37 +0000 (12:15 -0500)
Simplify scalar type size calculation and reject boolean memrefs.

Differential Revision: https://reviews.llvm.org/D72999

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir

index 6c7a453..8ba17a8 100644 (file)
@@ -51,6 +51,9 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
   {
     for (auto argType : enumerate(funcOp.getType().getInputs())) {
       auto convertedType = typeConverter.convertType(argType.value());
+      if (!convertedType) {
+        return matchFailure();
+      }
       signatureConverter.addInputs(argType.index(), convertedType);
     }
   }
index fc6c067..2d340b7 100644 (file)
@@ -41,10 +41,18 @@ Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
 // TODO(ravishankarm): This is a utility function that should probably be
 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t> getTypeNumBytes(Type t) {
-  if (auto integerType = t.dyn_cast<IntegerType>()) {
-    return integerType.getWidth() / 8;
-  } else if (auto floatType = t.dyn_cast<FloatType>()) {
-    return floatType.getWidth() / 8;
+  if (spirv::SPIRVDialect::isValidScalarType(t)) {
+    auto bitWidth = t.getIntOrFloatBitWidth();
+    // According to the SPIR-V spec:
+    // "There is no physical size or bit pattern defined for values with boolean
+    // type. If they are stored (in conjunction with OpVariable), they can only
+    // be used with logical addressing operations, not physical, and only with
+    // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
+    // Private, Function, Input, and Output."
+    if (bitWidth == 1) {
+      return llvm::None;
+    }
+    return bitWidth / 8;
   } else if (auto memRefType = t.dyn_cast<MemRefType>()) {
     // TODO: Layout should also be controlled by the ABI attributes. For now
     // using the layout from MemRef.
index 30b4c4a..7d0020a 100644 (file)
@@ -289,3 +289,12 @@ func @sitofp(%arg0 : i32) {
   %0 = std.sitofp %arg0 : i32 to f32
   return
 }
+
+//===----------------------------------------------------------------------===//
+// memref type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) {
+func @memref_type(%arg0: memref<3xi1>) {
+  return
+}