From: aartbik Date: Tue, 21 Jul 2020 17:57:18 +0000 (-0700) Subject: [mlir] [VectorOps] Add scatter/gather operations to Vector dialect X-Git-Tag: llvmorg-13-init~17266 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=19dbb230a245d3404a485d8684587c3d37c198d3;p=platform%2Fupstream%2Fllvm.git [mlir] [VectorOps] Add scatter/gather operations to Vector dialect Introduces the scatter/gather operations to the Vector dialect (important memory operations for sparse computations), together with a first reference implementation that lowers to the LLVM IR dialect to enable running on CPU (and other targets that support the corresponding LLVM IR intrinsics). The operations can be used directly where applicable, or can be used during progressively lowering to bring other memory operations closer to hardware ISA support for a gather/scatter. The semantics of the operation closely correspond to those of the corresponding llvm intrinsics. Note that the operation allows for a dynamic index vector (which is important for sparse computations). However, this first reference lowering implementation "serializes" the address computation when base + index_vector is converted to a vector of pointers. Exploring how to use SIMD properly during these step is TBD. More general memrefs and idiomatic versions of striding are also TBD. Reviewed By: arpith-jacob Differential Revision: https://reviews.llvm.org/D84039 --- diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index ce0b3de..f421d2e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -991,6 +991,35 @@ def LLVM_MaskedStoreOp let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` " "type($value) `,` type($mask) `into` type($data)"; } + +/// Create a call to Masked Gather intrinsic. +def LLVM_masked_gather + : LLVM_OneResultOp<"intr.masked.gather">, + Arguments<(ins LLVM_Type:$ptrs, LLVM_Type:$mask, + Variadic:$pass_thru, I32Attr:$alignment)> { + string llvmBuilder = [{ + $res = $pass_thru.empty() ? builder.CreateMaskedGather( + $ptrs, llvm::Align($alignment.getZExtValue()), $mask) : + builder.CreateMaskedGather( + $ptrs, llvm::Align($alignment.getZExtValue()), $mask, $pass_thru[0]); + }]; + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; +} + +/// Create a call to Masked Scatter intrinsic. +def LLVM_masked_scatter + : LLVM_ZeroResultOp<"intr.masked.scatter">, + Arguments<(ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask, + I32Attr:$alignment)> { + string llvmBuilder = [{ + builder.CreateMaskedScatter( + $value, $ptrs, llvm::Align($alignment.getZExtValue()), $mask); + }]; + let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` " + "type($value) `,` type($mask) `into` type($ptrs)"; +} + // // Atomic operations. // diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 10a4498..fd3d190 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1150,6 +1150,121 @@ def Vector_TransferWriteOp : let hasFolder = 1; } +def Vector_GatherOp : + Vector_Op<"gather">, + Arguments<(ins AnyMemRef:$base, + VectorOfRankAndType<[1], [AnyInteger]>:$indices, + VectorOfRankAndType<[1], [I1]>:$mask, + Variadic>:$pass_thru)>, + Results<(outs VectorOfRank<[1]>:$result)> { + + let summary = "gathers elements from memory into a vector as defined by an index vector"; + + let description = [{ + The gather operation gathers elements from memory into a 1-D vector as + defined by a base and a 1-D index vector, but only if the corresponding + bit is set in a 1-D mask vector. Otherwise, the element is taken from a + 1-D pass-through vector, if provided, or left undefined. Informally the + semantics are: + ``` + if (!defined(pass_thru)) pass_thru = [undef, .., undef] + result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0] + result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1] + etc. + ``` + The vector dialect leaves out-of-bounds behavior undefined. + + The gather operation can be used directly where applicable, or can be used + during progressively lowering to bring other memory operations closer to + hardware ISA support for a gather. The semantics of the operation closely + correspond to those of the `llvm.masked.gather` + [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics). + + Example: + + ```mlir + %g = vector.gather %base, %indices, %mask, %pass_thru + : (memref, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + ``` + + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getIndicesVectorType() { + return indices().getType().cast(); + } + VectorType getMaskVectorType() { + return mask().getType().cast(); + } + VectorType getPassThruVectorType() { + return (llvm::size(pass_thru()) == 0) + ? VectorType() + : (*pass_thru().begin()).getType().cast(); + } + VectorType getResultVectorType() { + return result().getType().cast(); + } + }]; + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; +} + +def Vector_ScatterOp : + Vector_Op<"scatter">, + Arguments<(ins AnyMemRef:$base, + VectorOfRankAndType<[1], [AnyInteger]>:$indices, + VectorOfRankAndType<[1], [I1]>:$mask, + VectorOfRank<[1]>:$value)> { + + let summary = "scatters elements from a vector into memory as defined by an index vector"; + + let description = [{ + The scatter operation scatters elements from a 1-D vector into memory as + defined by a base and a 1-D index vector, but only if the corresponding + bit in a 1-D mask vector is set. Otherwise, no action is taken for that + element. Informally the semantics are: + ``` + if (mask[0]) MEM[base + index[0]] = value[0] + if (mask[1]) MEM[base + index[1]] = value[1] + etc. + ``` + The vector dialect leaves out-of-bounds and repeated index behavior + undefined. Underlying implementations may enforce strict sequential + semantics for the latter, though. + TODO: enforce the latter always? + + The scatter operation can be used directly where applicable, or can be used + during progressively lowering to bring other memory operations closer to + hardware ISA support for a scatter. The semantics of the operation closely + correspond to those of the `llvm.masked.scatter` + [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics). + + Example: + + ```mlir + vector.scatter %base, %indices, %mask, %value + : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref + ``` + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getIndicesVectorType() { + return indices().getType().cast(); + } + VectorType getMaskVectorType() { + return mask().getType().cast(); + } + VectorType getValueVectorType() { + return value().getType().cast(); + } + }]; + let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` " + "type($indices) `,` type($mask) `,` type($value) `into` type($base)"; +} + def Vector_ShapeCastOp : Vector_Op<"shape_cast", [NoSideEffect]>, Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>, diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir new file mode 100644 index 0000000..5ed8f3e --- /dev/null +++ b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir @@ -0,0 +1,97 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @gather8(%base: memref, + %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> { + %g = vector.gather %base, %indices, %mask + : (memref, vector<8xi32>, vector<8xi1>) -> vector<8xf32> + return %g : vector<8xf32> +} + +func @gather_pass_thru8(%base: memref, + %indices: vector<8xi32>, %mask: vector<8xi1>, + %pass_thru: vector<8xf32>) -> vector<8xf32> { + %g = vector.gather %base, %indices, %mask, %pass_thru + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32> + return %g : vector<8xf32> +} + +func @entry() { + // Set up memory. + %c0 = constant 0: index + %c1 = constant 1: index + %c10 = constant 10: index + %A = alloc(%c10) : memref + scf.for %i = %c0 to %c10 step %c1 { + %i32 = index_cast %i : index to i32 + %fi = sitofp %i32 : i32 to f32 + store %fi, %A[%i] : memref + } + + // Set up idx vector. + %i0 = constant 0: i32 + %i1 = constant 1: i32 + %i2 = constant 2: i32 + %i3 = constant 3: i32 + %i4 = constant 4: i32 + %i5 = constant 5: i32 + %i6 = constant 6: i32 + %i9 = constant 9: i32 + %0 = vector.broadcast %i0 : i32 to vector<8xi32> + %1 = vector.insert %i6, %0[1] : i32 into vector<8xi32> + %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32> + %3 = vector.insert %i3, %2[3] : i32 into vector<8xi32> + %4 = vector.insert %i5, %3[4] : i32 into vector<8xi32> + %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32> + %6 = vector.insert %i9, %5[6] : i32 into vector<8xi32> + %idx = vector.insert %i2, %6[7] : i32 into vector<8xi32> + + // Set up pass thru vector. + %u = constant -7.0: f32 + %pass = vector.broadcast %u : f32 to vector<8xf32> + + // Set up masks. + %t = constant 1: i1 + %none = vector.constant_mask [0] : vector<8xi1> + %all = vector.constant_mask [8] : vector<8xi1> + %some = vector.constant_mask [4] : vector<8xi1> + %more = vector.insert %t, %some[7] : i1 into vector<8xi1> + + // + // Gather tests. + // + + %g1 = call @gather8(%A, %idx, %all) + : (memref, vector<8xi32>, vector<8xi1>) + -> (vector<8xf32>) + vector.print %g1 : vector<8xf32> + // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 ) + + %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) + -> (vector<8xf32>) + vector.print %g2 : vector<8xf32> + // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 ) + + %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) + -> (vector<8xf32>) + vector.print %g3 : vector<8xf32> + // CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 ) + + %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) + -> (vector<8xf32>) + vector.print %g4 : vector<8xf32> + // CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 ) + + %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) + -> (vector<8xf32>) + vector.print %g5 : vector<8xf32> + // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 ) + + return +} diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir new file mode 100644 index 0000000..6dd0cf1 --- /dev/null +++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir @@ -0,0 +1,135 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @scatter8(%base: memref, + %indices: vector<8xi32>, + %mask: vector<8xi1>, %value: vector<8xf32>) { + vector.scatter %base, %indices, %mask, %value + : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref + return +} + +func @printmem(%A: memref) { + %f = constant 0.0: f32 + %0 = vector.broadcast %f : f32 to vector<8xf32> + %1 = constant 0: index + %2 = load %A[%1] : memref + %3 = vector.insert %2, %0[0] : f32 into vector<8xf32> + %4 = constant 1: index + %5 = load %A[%4] : memref + %6 = vector.insert %5, %3[1] : f32 into vector<8xf32> + %7 = constant 2: index + %8 = load %A[%7] : memref + %9 = vector.insert %8, %6[2] : f32 into vector<8xf32> + %10 = constant 3: index + %11 = load %A[%10] : memref + %12 = vector.insert %11, %9[3] : f32 into vector<8xf32> + %13 = constant 4: index + %14 = load %A[%13] : memref + %15 = vector.insert %14, %12[4] : f32 into vector<8xf32> + %16 = constant 5: index + %17 = load %A[%16] : memref + %18 = vector.insert %17, %15[5] : f32 into vector<8xf32> + %19 = constant 6: index + %20 = load %A[%19] : memref + %21 = vector.insert %20, %18[6] : f32 into vector<8xf32> + %22 = constant 7: index + %23 = load %A[%22] : memref + %24 = vector.insert %23, %21[7] : f32 into vector<8xf32> + vector.print %24 : vector<8xf32> + return +} + +func @entry() { + // Set up memory. + %c0 = constant 0: index + %c1 = constant 1: index + %c8 = constant 8: index + %A = alloc(%c8) : memref + scf.for %i = %c0 to %c8 step %c1 { + %i32 = index_cast %i : index to i32 + %fi = sitofp %i32 : i32 to f32 + store %fi, %A[%i] : memref + } + + // Set up idx vector. + %i0 = constant 0: i32 + %i1 = constant 1: i32 + %i2 = constant 2: i32 + %i3 = constant 3: i32 + %i4 = constant 4: i32 + %i5 = constant 5: i32 + %i6 = constant 6: i32 + %i7 = constant 7: i32 + %0 = vector.broadcast %i7 : i32 to vector<8xi32> + %1 = vector.insert %i0, %0[1] : i32 into vector<8xi32> + %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32> + %3 = vector.insert %i6, %2[3] : i32 into vector<8xi32> + %4 = vector.insert %i2, %3[4] : i32 into vector<8xi32> + %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32> + %6 = vector.insert %i5, %5[6] : i32 into vector<8xi32> + %idx = vector.insert %i3, %6[7] : i32 into vector<8xi32> + + // Set up value vector. + %f0 = constant 0.0: f32 + %f1 = constant 1.0: f32 + %f2 = constant 2.0: f32 + %f3 = constant 3.0: f32 + %f4 = constant 4.0: f32 + %f5 = constant 5.0: f32 + %f6 = constant 6.0: f32 + %f7 = constant 7.0: f32 + %7 = vector.broadcast %f0 : f32 to vector<8xf32> + %8 = vector.insert %f1, %7[1] : f32 into vector<8xf32> + %9 = vector.insert %f2, %8[2] : f32 into vector<8xf32> + %10 = vector.insert %f3, %9[3] : f32 into vector<8xf32> + %11 = vector.insert %f4, %10[4] : f32 into vector<8xf32> + %12 = vector.insert %f5, %11[5] : f32 into vector<8xf32> + %13 = vector.insert %f6, %12[6] : f32 into vector<8xf32> + %val = vector.insert %f7, %13[7] : f32 into vector<8xf32> + + // Set up masks. + %t = constant 1: i1 + %none = vector.constant_mask [0] : vector<8xi1> + %some = vector.constant_mask [4] : vector<8xi1> + %more = vector.insert %t, %some[7] : i1 into vector<8xi1> + %all = vector.constant_mask [8] : vector<8xi1> + + // + // Scatter tests. + // + + vector.print %idx : vector<8xi32> + // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 ) + + call @printmem(%A) : (memref) -> () + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 ) + + call @scatter8(%A, %idx, %none, %val) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> () + + call @printmem(%A) : (memref) -> () + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 ) + + call @scatter8(%A, %idx, %some, %val) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> () + + call @printmem(%A) : (memref) -> () + // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 ) + + call @scatter8(%A, %idx, %more, %val) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> () + + call @printmem(%A) : (memref) -> () + // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 ) + + call @scatter8(%A, %idx, %all, %val) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> () + + call @printmem(%A) : (memref) -> () + // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 ) + + return +} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a59f026..a877bd1 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -34,14 +34,6 @@ using namespace mlir; using namespace mlir::vector; -template -static LLVM::LLVMType getPtrToElementType(T containerType, - LLVMTypeConverter &typeConverter) { - return typeConverter.convertType(containerType.getElementType()) - .template cast() - .getPointerTo(); -} - // Helper to reduce vector type by one rank at front. static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); @@ -124,11 +116,12 @@ static SmallVector getI64SubArray(ArrayAttr arrayAttr, return res; } -template -LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter, - TransferOp xferOp, unsigned &align) { +// Helper that returns data layout alignment of an operation with memref. +template +LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, + unsigned &align) { Type elementTy = - typeConverter.convertType(xferOp.getMemRefType().getElementType()); + typeConverter.convertType(op.getMemRefType().getElementType()); if (!elementTy) return failure(); @@ -138,13 +131,54 @@ LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter, return success(); } +// Helper that returns vector of pointers given a base and an index vector. +LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, Location loc, + Value memref, Value indices, MemRefType memRefType, + VectorType vType, Type iType, Value &ptrs) { + // Inspect stride and offset structure. + // + // TODO: flat memory only for now, generalize + // + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(memRefType, strides, offset); + if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || + offset != 0 || memRefType.getMemorySpace() != 0) + return failure(); + + // Base pointer. + MemRefDescriptor memRefDescriptor(memref); + Value base = memRefDescriptor.alignedPtr(rewriter, loc); + + // Create a vector of pointers from base and indices. + // + // TODO: this step serializes the address computations unfortunately, + // ideally we would like to add splat(base) + index_vector + // in SIMD form, but this does not match well with current + // constraints of the standard and vector dialect.... + // + int64_t size = vType.getDimSize(0); + auto pType = memRefDescriptor.getElementType(); + auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size); + auto idxType = typeConverter.convertType(iType); + ptrs = rewriter.create(loc, ptrsType); + for (int64_t i = 0; i < size; i++) { + Value off = + extractOne(rewriter, typeConverter, loc, indices, idxType, 1, i); + Value ptr = rewriter.create(loc, pType, base, off); + ptrs = insertOne(rewriter, typeConverter, loc, ptrs, ptr, ptrsType, 1, i); + } + return success(); +} + static LogicalResult replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, ArrayRef operands, Value dataPtr) { unsigned align; - if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); rewriter.replaceOpWithNewOp(xferOp, dataPtr, align); return success(); @@ -165,7 +199,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, return failure(); unsigned align; - if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); rewriter.replaceOpWithNewOp( @@ -180,7 +214,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, TransferWriteOp xferOp, ArrayRef operands, Value dataPtr) { unsigned align; - if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr, @@ -194,7 +228,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, TransferWriteOp xferOp, ArrayRef operands, Value dataPtr, Value mask) { unsigned align; - if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); auto adaptor = TransferWriteOpAdaptor(operands); @@ -259,6 +293,83 @@ public: } }; +/// Conversion pattern for a vector.gather. +class VectorGatherOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorGatherOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto gather = cast(op); + auto adaptor = vector::GatherOpAdaptor(operands); + + // Resolve alignment. + unsigned align; + if (failed(getMemRefAlignment(typeConverter, gather, align))) + return failure(); + + // Get index ptrs. + VectorType vType = gather.getResultVectorType(); + Type iType = gather.getIndicesVectorType().getElementType(); + Value ptrs; + if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(), + adaptor.indices(), gather.getMemRefType(), vType, + iType, ptrs))) + return failure(); + + // Replace with the gather intrinsic. + ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({}) + : adaptor.pass_thru(); + rewriter.replaceOpWithNewOp( + gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v, + rewriter.getI32IntegerAttr(align)); + return success(); + } +}; + +/// Conversion pattern for a vector.scatter. +class VectorScatterOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorScatterOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto scatter = cast(op); + auto adaptor = vector::ScatterOpAdaptor(operands); + + // Resolve alignment. + unsigned align; + if (failed(getMemRefAlignment(typeConverter, scatter, align))) + return failure(); + + // Get index ptrs. + VectorType vType = scatter.getValueVectorType(); + Type iType = scatter.getIndicesVectorType().getElementType(); + Value ptrs; + if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(), + adaptor.indices(), scatter.getMemRefType(), vType, + iType, ptrs))) + return failure(); + + // Replace with the scatter intrinsic. + rewriter.replaceOpWithNewOp( + scatter, adaptor.value(), ptrs, adaptor.mask(), + rewriter.getI32IntegerAttr(align)); + return success(); + } +}; + +/// Conversion pattern for all vector reductions. class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, @@ -1173,7 +1284,9 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorPrintOpConversion, VectorTransferConversion, VectorTransferConversion, - VectorTypeCastOpConversion>(ctx, converter); + VectorTypeCastOpConversion, + VectorGatherOpConversion, + VectorScatterOpConversion>(ctx, converter); // clang-format on } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 03c4079..d16c7c3 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1859,6 +1859,49 @@ Optional> TransferWriteOp::getShapeForUnroll() { } //===----------------------------------------------------------------------===// +// GatherOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(GatherOp op) { + VectorType indicesVType = op.getIndicesVectorType(); + VectorType maskVType = op.getMaskVectorType(); + VectorType resVType = op.getResultVectorType(); + + if (resVType.getElementType() != op.getMemRefType().getElementType()) + return op.emitOpError("base and result element type should match"); + + if (resVType.getDimSize(0) != indicesVType.getDimSize(0)) + return op.emitOpError("expected result dim to match indices dim"); + if (resVType.getDimSize(0) != maskVType.getDimSize(0)) + return op.emitOpError("expected result dim to match mask dim"); + if (llvm::size(op.pass_thru()) != 0) { + VectorType passVType = op.getPassThruVectorType(); + if (resVType != passVType) + return op.emitOpError("expected pass_thru of same type as result type"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ScatterOp op) { + VectorType indicesVType = op.getIndicesVectorType(); + VectorType maskVType = op.getMaskVectorType(); + VectorType valueVType = op.getValueVectorType(); + + if (valueVType.getElementType() != op.getMemRefType().getElementType()) + return op.emitOpError("base and value element type should match"); + + if (valueVType.getDimSize(0) != indicesVType.getDimSize(0)) + return op.emitOpError("expected value dim to match indices dim"); + if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) + return op.emitOpError("expected value dim to match mask dim"); + return success(); +} + +//===----------------------------------------------------------------------===// // ShapeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 874cb5c..69d3aec 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -969,3 +969,21 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> { // CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} : // CHECK-SAME: !llvm<"<16 x float>"> into !llvm<"<16 x float>"> // CHECK: llvm.return %[[T]] : !llvm<"<16 x float>"> + +func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { + %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> + return %0 : vector<3xf32> +} + +// CHECK-LABEL: func @gather_op +// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> +// CHECK: llvm.return %[[G]] : !llvm<"<3 x float>"> + +func @scatter_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) { + vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref + return +} + +// CHECK-LABEL: func @scatter_op +// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>"> +// CHECK: llvm.return diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 9164038..ea354f5 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1177,3 +1177,66 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> // expected-error@+1 {{expects operand to be a memref with no layout}} %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref> } + +// ----- + +func @gather_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { + // expected-error@+1 {{'vector.gather' op base and result element type should match}} + %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32> +} + +// ----- + +func @gather_rank_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { + // expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}} + %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32> +} + +// ----- + +func @gather_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, %mask: vector<16xi1>) { + // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}} + %0 = vector.gather %base, %indices, %mask : (memref, vector<17xi32>, vector<16xi1>) -> vector<16xf32> +} + +// ----- + +func @gather_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<17xi1>) { + // expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}} + %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<17xi1>) -> vector<16xf32> +} + +// ----- + +func @gather_pass_thru_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) { + // expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}} + %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32> +} + +// ----- + +func @scatter_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { + // expected-error@+1 {{'vector.scatter' op base and value element type should match}} + vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref +} + +// ----- + +func @scatter_rank_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) { + // expected-error@+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}} + vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref +} + +// ----- + +func @scatter_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { + // expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}} + vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref +} + +// ----- + +func @scatter_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) { + // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}} + vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref +} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 4ea7286..0bf4ed8 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -368,3 +368,14 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> { // CHECK: return %[[X]] : vector<16xi32> return %0 : vector<16xi32> } + +// CHECK-LABEL: @gather_and_scatter +func @gather_and_scatter(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { + // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32> + %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32> + // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + %1 = vector.gather %base, %indices, %mask, %0 : (memref, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref + vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref + return +} diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir index 1595edf..79b7edb 100644 --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -206,6 +206,20 @@ llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>"> llvm.return } +// CHECK-LABEL: @masked_gather_scatter_intrinsics +llvm.func @masked_gather_scatter_intrinsics(%M: !llvm<"<7 x float*>">, %mask: !llvm<"<7 x i1>">) { + // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef) + %a = llvm.intr.masked.gather %M, %mask { alignment = 1: i32} : + (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">) -> !llvm<"<7 x float>"> + // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> %{{.*}}) + %b = llvm.intr.masked.gather %M, %mask, %a { alignment = 1: i32} : + (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">, !llvm<"<7 x float>">) -> !llvm<"<7 x float>"> + // CHECK: call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %{{.*}}, <7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}) + llvm.intr.masked.scatter %b, %M, %mask { alignment = 1: i32} : + !llvm<"<7 x float>">, !llvm<"<7 x i1>"> into !llvm<"<7 x float*>"> + llvm.return +} + // CHECK-LABEL: @memcpy_test llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) { // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})