[mlir][SPIRVToLLVM] Add shufflevector conversion
authorWeiwei Li <weiwei.li1@huawei.com>
Mon, 1 Nov 2021 14:47:47 +0000 (22:47 +0800)
committerWeiwei Li <weiwei.li1@huawei.com>
Mon, 1 Nov 2021 15:05:37 +0000 (23:05 +0800)
Add the shufflevector conversion. It only handles the static, i.e., IntegerAttr, index.

Co-authored: Xinyi Liu <xyliuhelen@gmail.com>

Reviewed by: antiagainst

Differential revision: https://reviews.llvm.org/D112161

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir

index 51df6c7..b416c30 100644 (file)
@@ -1364,6 +1364,65 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorShuffleOp conversion
+//===----------------------------------------------------------------------===//
+
+class VectorShufflePattern
+    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
+  LogicalResult
+  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto components = adaptor.components();
+    auto vector1 = adaptor.vector1();
+    auto vector2 = adaptor.vector2();
+    int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
+    int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
+    if (vector1Size == vector2Size) {
+      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
+                                                         components);
+      return success();
+    }
+
+    auto dstType = typeConverter.convertType(op.getType());
+    auto scalarType = dstType.cast<VectorType>().getElementType();
+    auto componentsArray = components.getValue();
+    auto context = rewriter.getContext();
+    auto llvmI32Type = IntegerType::get(context, 32);
+    Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
+    for (unsigned i = 0; i < componentsArray.size(); i++) {
+      if (componentsArray[i].isa<IntegerAttr>())
+        op.emitError("unable to support non-constant component");
+
+      int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
+      if (indexVal == -1)
+        continue;
+
+      int offsetVal = 0;
+      Value baseVector = vector1;
+      if (indexVal >= vector1Size) {
+        offsetVal = vector1Size;
+        baseVector = vector2;
+      }
+
+      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
+      Value index = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type,
+          rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
+
+      auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
+          loc, scalarType, baseVector, index);
+      targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
+                                                        extractOp, dstIndex);
+    }
+    rewriter.replaceOp(op, targetOp);
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1489,6 +1548,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       CompositeExtractPattern, CompositeInsertPattern,
       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
+      VectorShufflePattern,
 
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
index c8528d0..38d55b9 100644 (file)
@@ -59,6 +59,32 @@ spv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+spv.func @vector_shuffle_same_size(%vector1: vector<2xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  //      CHECK: %[[res:.*]] = llvm.shufflevector {{.*}} [0 : i32, 2 : i32, -1 : i32] : vector<2xf32>, vector<2xf32>
+  // CHECK-NEXT: return %[[res]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 2: i32, 0xffffffff: i32] %vector1: vector<2xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
+spv.func @vector_shuffle_different_size(%vector1: vector<3xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  //      CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : vector<3xf32>
+  // CHECK-NEXT: %[[C0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[EXT0:.*]] = llvm.extractelement %arg0[%[[C0_1]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[INSERT0:.*]] = llvm.insertelement %[[EXT0]], %[[UNDEF]][%[[C0_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[C1_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[C1_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[EXT1:.*]] = llvm.extractelement {{.*}}[%[[C1_1]] : i32] : vector<2xf32>
+  // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[EXT1]], %[[INSERT0]][%[[C1_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: llvm.return %[[RES]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 4: i32, 0xffffffff: i32] %vector1: vector<3xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
+//===----------------------------------------------------------------------===//
 // spv.EntryPoint and spv.ExecutionMode
 //===----------------------------------------------------------------------===//