[mlir][spirv] Convert some 0-D vector extract/insertelement ops
authorLei Zhang <antiagainst@google.com>
Fri, 2 Sep 2022 21:47:31 +0000 (17:47 -0400)
committerLei Zhang <antiagainst@google.com>
Fri, 2 Sep 2022 21:47:49 +0000 (17:47 -0400)
Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D133183

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

index 5349261..23f664c 100644 (file)
@@ -41,7 +41,7 @@ struct VectorBitcastConvert final
   LogicalResult
   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
+    Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
     if (!dstType)
       return failure();
 
@@ -60,15 +60,21 @@ struct VectorBroadcastConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (broadcastOp.getSource().getType().isa<VectorType>() ||
-        !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
+    Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
+    if (!resultType)
       return failure();
-    SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
+
+    if (resultType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(castOp, adaptor.getSource());
+      return success();
+    }
+
+    SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
                                  adaptor.getSource());
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        broadcastOp, broadcastOp.getVectorType(), source);
+        castOp, castOp.getVectorType(), source);
     return success();
   }
 };
@@ -85,7 +91,7 @@ struct VectorExtractOpConvert final
     if (resultVectorType && resultVectorType.getNumElements() > 1)
       return failure();
 
-    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
 
@@ -108,7 +114,7 @@ struct VectorExtractStridedSliceOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
 
@@ -183,13 +189,21 @@ struct VectorExtractElementOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
+    Type vectorType =
+        getTypeConverter()->convertType(adaptor.getVector().getType());
+    if (!vectorType)
       return failure();
+
+    if (vectorType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(extractOp, adaptor.getVector());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-        extractElementOp, extractElementOp.getType(), adaptor.getVector(),
-        extractElementOp.getPosition());
+        extractOp, extractOp.getType(), adaptor.getVector(),
+        extractOp.getPosition());
     return success();
   }
 };
@@ -199,13 +213,20 @@ struct VectorInsertElementOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
+    Type vectorType = getTypeConverter()->convertType(insertOp.getType());
+    if (!vectorType)
       return failure();
+
+    if (vectorType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(insertOp, adaptor.getSource());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-        insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
-        adaptor.getSource(), insertElementOp.getPosition());
+        insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
+        insertOp.getPosition());
     return success();
   }
 };
@@ -354,7 +375,7 @@ struct VectorShuffleOpConvert final
     auto oldResultType = shuffleOp.getVectorType();
     if (!spirv::CompositeType::isValid(oldResultType))
       return failure();
-    auto newResultType = getTypeConverter()->convertType(oldResultType);
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
 
     auto oldSourceType = shuffleOp.getV1VectorType();
     if (oldSourceType.getNumElements() > 1) {
index c2e2286..fc2f8c9 100644 (file)
@@ -235,7 +235,7 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
                               const SPIRVTypeConverter::Options &options,
                               VectorType type,
                               Optional<spirv::StorageClass> storageClass = {}) {
-  if (type.getRank() == 1 && type.getNumElements() == 1)
+  if (type.getRank() <= 1 && type.getNumElements() == 1)
     return type.getElementType();
 
   if (!spirv::CompositeType::isValid(type)) {
index 4ff67ce..0d3b094 100644 (file)
@@ -240,6 +240,10 @@ func.func @float_vector(
 // CHECK-SAME: %{{.+}}: i32
 func.func @one_element_vector(%arg0: vector<1xi32>) { return }
 
+// CHECK-LABEL: spv.func @zerod_vector
+//  CHECK-SAME: %{{.+}}: f32
+func.func @zerod_vector(%arg0: vector<f32>) { return }
+
 } // end module
 
 // -----
index f1de62c..f483fc3 100644 (file)
@@ -121,6 +121,28 @@ func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32
 
 // -----
 
+// CHECK-LABEL: @extract_element_size1_vector
+//  CHECK-SAME: (%[[S:.+]]: f32
+func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 {
+  %bcast = vector.broadcast %arg0 : f32 to vector<1xf32>
+  %0 = vector.extractelement %bcast[%i : index] : vector<1xf32>
+  // CHECK: return %[[S]]
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element_0d_vector
+//  CHECK-SAME: (%[[S:.+]]: f32)
+func.func @extract_element_0d_vector(%arg0 : f32) -> f32 {
+  %bcast = vector.broadcast %arg0 : f32 to vector<f32>
+  %0 = vector.extractelement %bcast[] : vector<f32>
+  // CHECK: return %[[S]]
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_strided_slice
 //  CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
 //       CHECK:   spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
@@ -161,6 +183,28 @@ func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i
 
 // -----
 
+// CHECK-LABEL: @insert_element_size1_vector
+//  CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> {
+  %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32>
+  // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<1xf32>
+  // CHECK: return %[[V]]
+  return %0: vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element_0d_vector
+//  CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> {
+  %0 = vector.insertelement %scalar, %vector[] : vector<f32>
+  // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<f32>
+  // CHECK: return %[[V]]
+  return %0: vector<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_strided_slice
 //  CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
 //       CHECK:   spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>