LogicalResult
matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
+ Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
if (!dstType)
return failure();
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (broadcastOp.getSource().getType().isa<VectorType>() ||
- !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
+ Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
+ if (!resultType)
return failure();
- SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
+
+ if (resultType.isa<spirv::ScalarType>()) {
+ rewriter.replaceOp(castOp, adaptor.getSource());
+ return success();
+ }
+
+ SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
adaptor.getSource());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
- broadcastOp, broadcastOp.getVectorType(), source);
+ castOp, castOp.getVectorType(), source);
return success();
}
};
if (resultVectorType && resultVectorType.getNumElements() > 1)
return failure();
- auto dstType = getTypeConverter()->convertType(extractOp.getType());
+ Type dstType = getTypeConverter()->convertType(extractOp.getType());
if (!dstType)
return failure();
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = getTypeConverter()->convertType(extractOp.getType());
+ Type dstType = getTypeConverter()->convertType(extractOp.getType());
if (!dstType)
return failure();
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
+ Type vectorType =
+ getTypeConverter()->convertType(adaptor.getVector().getType());
+ if (!vectorType)
return failure();
+
+ if (vectorType.isa<spirv::ScalarType>()) {
+ rewriter.replaceOp(extractOp, adaptor.getVector());
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractElementOp, extractElementOp.getType(), adaptor.getVector(),
- extractElementOp.getPosition());
+ extractOp, extractOp.getType(), adaptor.getVector(),
+ extractOp.getPosition());
return success();
}
};
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
+ Type vectorType = getTypeConverter()->convertType(insertOp.getType());
+ if (!vectorType)
return failure();
+
+ if (vectorType.isa<spirv::ScalarType>()) {
+ rewriter.replaceOp(insertOp, adaptor.getSource());
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
- adaptor.getSource(), insertElementOp.getPosition());
+ insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
+ insertOp.getPosition());
return success();
}
};
auto oldResultType = shuffleOp.getVectorType();
if (!spirv::CompositeType::isValid(oldResultType))
return failure();
- auto newResultType = getTypeConverter()->convertType(oldResultType);
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
auto oldSourceType = shuffleOp.getV1VectorType();
if (oldSourceType.getNumElements() > 1) {
// -----
+// CHECK-LABEL: @extract_element_size1_vector
+// CHECK-SAME: (%[[S:.+]]: f32
+func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 {
+ %bcast = vector.broadcast %arg0 : f32 to vector<1xf32>
+ %0 = vector.extractelement %bcast[%i : index] : vector<1xf32>
+ // CHECK: return %[[S]]
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element_0d_vector
+// CHECK-SAME: (%[[S:.+]]: f32)
+func.func @extract_element_0d_vector(%arg0 : f32) -> f32 {
+ %bcast = vector.broadcast %arg0 : f32 to vector<f32>
+ %0 = vector.extractelement %bcast[] : vector<f32>
+ // CHECK: return %[[S]]
+ return %0: f32
+}
+
+// -----
+
// CHECK-LABEL: @extract_strided_slice
// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
// CHECK: spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
// -----
+// CHECK-LABEL: @insert_element_size1_vector
+// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> {
+ %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32>
+ // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<1xf32>
+ // CHECK: return %[[V]]
+ return %0: vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element_0d_vector
+// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> {
+ %0 = vector.insertelement %scalar, %vector[] : vector<f32>
+ // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<f32>
+ // CHECK: return %[[V]]
+ return %0: vector<f32>
+}
+
+// -----
+
// CHECK-LABEL: @insert_strided_slice
// CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
// CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>