From b2e72cd38de859194b18d598fdfe704315be3d36 Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Mon, 19 Apr 2021 21:24:06 -0700 Subject: [PATCH] [mlir][spirv] Support conversion of extract op from vector<1xT> type Differential Revision: https://reviews.llvm.org/D100814 --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 5 +++++ mlir/test/Conversion/VectorToSPIRV/simple.mlir | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 4cfcb41..edabae7 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -89,6 +89,11 @@ struct VectorExtractOpConvert final return failure(); vector::ExtractOp::Adaptor adaptor(operands); + if (adaptor.vector().getType().isa()) { + rewriter.replaceOp(extractOp, adaptor.vector()); + return success(); + } + int32_t id = getFirstIntValue(extractOp.position()); rewriter.replaceOpWithNewOp( extractOp, adaptor.vector(), id); diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir index 836d385..9f9657c 100644 --- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -40,6 +40,24 @@ func @extract(%arg0 : vector<2xf32>) { // ----- +module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { + +// CHECK-LABEL: func @extract_scalar +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf16> +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK: %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32 +// CHECK: spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32> +func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) { + %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32> + %1 = vector.extract %0[0] : vector<1xf32> + %2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32> + spv.Return +} + +} // end module + +// ----- + // CHECK-LABEL: extract_insert // CHECK-SAME: %[[V:.*]]: vector<4xf32> // CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -- 2.7.4