let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` "
"type($value) `,` type($mask) `into` type($data)";
}
+
+/// Create a call to the masked gather intrinsic.
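+/// For example (operand names are illustrative; syntax as exercised by the
+/// roundtrip tests):
+///   %a = llvm.intr.masked.gather %ptrs, %mask { alignment = 1: i32 } :
+///       (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">) -> !llvm<"<7 x float>">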
+def LLVM_masked_gather
+ : LLVM_OneResultOp<"intr.masked.gather">,
+ Arguments<(ins LLVM_Type:$ptrs, LLVM_Type:$mask,
+ Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment)> {
+ string llvmBuilder = [{
+    $res = $pass_thru.empty()
+               ? builder.CreateMaskedGather(
+                     $ptrs, llvm::Align($alignment.getZExtValue()), $mask)
+               : builder.CreateMaskedGather(
+                     $ptrs, llvm::Align($alignment.getZExtValue()), $mask,
+                     $pass_thru[0]);
+ }];
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
+}
+
+/// Create a call to the masked scatter intrinsic.
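+/// For example (operand names are illustrative; syntax as exercised by the
+/// roundtrip tests):
+///   llvm.intr.masked.scatter %val, %ptrs, %mask { alignment = 1: i32 } :
+///       !llvm<"<7 x float>">, !llvm<"<7 x i1>"> into !llvm<"<7 x float*>">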
+def LLVM_masked_scatter
+ : LLVM_ZeroResultOp<"intr.masked.scatter">,
+ Arguments<(ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask,
+ I32Attr:$alignment)> {
+ string llvmBuilder = [{
+ builder.CreateMaskedScatter(
+ $value, $ptrs, llvm::Align($alignment.getZExtValue()), $mask);
+ }];
+ let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` "
+ "type($value) `,` type($mask) `into` type($ptrs)";
+}
+
//
// Atomic operations.
//
let hasFolder = 1;
}
+def Vector_GatherOp :
+ Vector_Op<"gather">,
+ Arguments<(ins AnyMemRef:$base,
+ VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+ VectorOfRankAndType<[1], [I1]>:$mask,
+ Variadic<VectorOfRank<[1]>>:$pass_thru)>,
+ Results<(outs VectorOfRank<[1]>:$result)> {
+
+ let summary = "gathers elements from memory into a vector as defined by an index vector";
+
+ let description = [{
+ The gather operation gathers elements from memory into a 1-D vector as
+ defined by a base and a 1-D index vector, but only if the corresponding
+ bit is set in a 1-D mask vector. Otherwise, the element is taken from a
+ 1-D pass-through vector, if provided, or left undefined. Informally the
+ semantics are:
+ ```
+ if (!defined(pass_thru)) pass_thru = [undef, .., undef]
+ result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
+ result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
+ etc.
+ ```
+ The vector dialect leaves out-of-bounds behavior undefined.
+
+ The gather operation can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+ hardware ISA support for a gather. The semantics of the operation closely
+ correspond to those of the `llvm.masked.gather`
+ [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
+
+ Example:
+
+ ```mlir
+ %g = vector.gather %base, %indices, %mask, %pass_thru
+ : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ ```
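+
+    If no pass-through vector is given, masked-off lanes of the result are
+    left undefined:
+
+    ```mlir
+    %g = vector.gather %base, %indices, %mask
+        : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+    ```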
+
+ }];
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ VectorType getIndicesVectorType() {
+ return indices().getType().cast<VectorType>();
+ }
+ VectorType getMaskVectorType() {
+ return mask().getType().cast<VectorType>();
+ }
+ VectorType getPassThruVectorType() {
+ return (llvm::size(pass_thru()) == 0)
+ ? VectorType()
+ : (*pass_thru().begin()).getType().cast<VectorType>();
+ }
+ VectorType getResultVectorType() {
+ return result().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+}
+
+def Vector_ScatterOp :
+ Vector_Op<"scatter">,
+ Arguments<(ins AnyMemRef:$base,
+ VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+ VectorOfRankAndType<[1], [I1]>:$mask,
+ VectorOfRank<[1]>:$value)> {
+
+ let summary = "scatters elements from a vector into memory as defined by an index vector";
+
+ let description = [{
+ The scatter operation scatters elements from a 1-D vector into memory as
+ defined by a base and a 1-D index vector, but only if the corresponding
+ bit in a 1-D mask vector is set. Otherwise, no action is taken for that
+ element. Informally the semantics are:
+ ```
+ if (mask[0]) MEM[base + index[0]] = value[0]
+ if (mask[1]) MEM[base + index[1]] = value[1]
+ etc.
+ ```
+    The vector dialect leaves out-of-bounds and repeated-index behavior
+    undefined; for example, when two lanes scatter to the same address, the
+    value that ends up in memory is unspecified. Underlying implementations
+    may nevertheless enforce strict sequential semantics for repeated indices.
+    TODO: enforce the latter always?
+
+ The scatter operation can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+ hardware ISA support for a scatter. The semantics of the operation closely
+ correspond to those of the `llvm.masked.scatter`
+ [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
+
+ Example:
+
+ ```mlir
+ vector.scatter %base, %indices, %mask, %value
+        : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ VectorType getIndicesVectorType() {
+ return indices().getType().cast<VectorType>();
+ }
+ VectorType getMaskVectorType() {
+ return mask().getType().cast<VectorType>();
+ }
+ VectorType getValueVectorType() {
+ return value().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
+ "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+}
+
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [NoSideEffect]>,
Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
--- /dev/null
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @gather8(%base: memref<?xf32>,
+ %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> {
+ %g = vector.gather %base, %indices, %mask
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
+ return %g : vector<8xf32>
+}
+
+func @gather_pass_thru8(%base: memref<?xf32>,
+ %indices: vector<8xi32>, %mask: vector<8xi1>,
+ %pass_thru: vector<8xf32>) -> vector<8xf32> {
+ %g = vector.gather %base, %indices, %mask, %pass_thru
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32>
+ return %g : vector<8xf32>
+}
+
+func @entry() {
+ // Set up memory.
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c10 = constant 10: index
+ %A = alloc(%c10) : memref<?xf32>
+ scf.for %i = %c0 to %c10 step %c1 {
+ %i32 = index_cast %i : index to i32
+ %fi = sitofp %i32 : i32 to f32
+ store %fi, %A[%i] : memref<?xf32>
+ }
+
+ // Set up idx vector.
+ %i0 = constant 0: i32
+ %i1 = constant 1: i32
+ %i2 = constant 2: i32
+ %i3 = constant 3: i32
+ %i4 = constant 4: i32
+ %i5 = constant 5: i32
+ %i6 = constant 6: i32
+ %i9 = constant 9: i32
+ %0 = vector.broadcast %i0 : i32 to vector<8xi32>
+ %1 = vector.insert %i6, %0[1] : i32 into vector<8xi32>
+ %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32>
+ %3 = vector.insert %i3, %2[3] : i32 into vector<8xi32>
+ %4 = vector.insert %i5, %3[4] : i32 into vector<8xi32>
+ %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32>
+ %6 = vector.insert %i9, %5[6] : i32 into vector<8xi32>
+ %idx = vector.insert %i2, %6[7] : i32 into vector<8xi32>
+
+ // Set up pass thru vector.
+ %u = constant -7.0: f32
+ %pass = vector.broadcast %u : f32 to vector<8xf32>
+
+ // Set up masks.
+ %t = constant 1: i1
+ %none = vector.constant_mask [0] : vector<8xi1>
+ %all = vector.constant_mask [8] : vector<8xi1>
+ %some = vector.constant_mask [4] : vector<8xi1>
+ %more = vector.insert %t, %some[7] : i1 into vector<8xi1>
+
+ //
+ // Gather tests.
+ //
+
+ %g1 = call @gather8(%A, %idx, %all)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>)
+ -> (vector<8xf32>)
+ vector.print %g1 : vector<8xf32>
+ // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+
+ %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+ -> (vector<8xf32>)
+ vector.print %g2 : vector<8xf32>
+ // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 )
+
+ %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+ -> (vector<8xf32>)
+ vector.print %g3 : vector<8xf32>
+ // CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 )
+
+ %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+ -> (vector<8xf32>)
+ vector.print %g4 : vector<8xf32>
+ // CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 )
+
+ %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+ -> (vector<8xf32>)
+ vector.print %g5 : vector<8xf32>
+ // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+
+ return
+}
--- /dev/null
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @scatter8(%base: memref<?xf32>,
+ %indices: vector<8xi32>,
+ %mask: vector<8xi1>, %value: vector<8xf32>) {
+ vector.scatter %base, %indices, %mask, %value
+ : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref<?xf32>
+ return
+}
+
+func @printmem(%A: memref<?xf32>) {
+ %f = constant 0.0: f32
+ %0 = vector.broadcast %f : f32 to vector<8xf32>
+ %1 = constant 0: index
+ %2 = load %A[%1] : memref<?xf32>
+ %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
+ %4 = constant 1: index
+ %5 = load %A[%4] : memref<?xf32>
+ %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
+ %7 = constant 2: index
+ %8 = load %A[%7] : memref<?xf32>
+ %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
+ %10 = constant 3: index
+ %11 = load %A[%10] : memref<?xf32>
+ %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
+ %13 = constant 4: index
+ %14 = load %A[%13] : memref<?xf32>
+ %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
+ %16 = constant 5: index
+ %17 = load %A[%16] : memref<?xf32>
+ %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
+ %19 = constant 6: index
+ %20 = load %A[%19] : memref<?xf32>
+ %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
+ %22 = constant 7: index
+ %23 = load %A[%22] : memref<?xf32>
+ %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
+ vector.print %24 : vector<8xf32>
+ return
+}
+
+func @entry() {
+ // Set up memory.
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c8 = constant 8: index
+ %A = alloc(%c8) : memref<?xf32>
+ scf.for %i = %c0 to %c8 step %c1 {
+ %i32 = index_cast %i : index to i32
+ %fi = sitofp %i32 : i32 to f32
+ store %fi, %A[%i] : memref<?xf32>
+ }
+
+ // Set up idx vector.
+ %i0 = constant 0: i32
+ %i1 = constant 1: i32
+ %i2 = constant 2: i32
+ %i3 = constant 3: i32
+ %i4 = constant 4: i32
+ %i5 = constant 5: i32
+ %i6 = constant 6: i32
+ %i7 = constant 7: i32
+ %0 = vector.broadcast %i7 : i32 to vector<8xi32>
+ %1 = vector.insert %i0, %0[1] : i32 into vector<8xi32>
+ %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32>
+ %3 = vector.insert %i6, %2[3] : i32 into vector<8xi32>
+ %4 = vector.insert %i2, %3[4] : i32 into vector<8xi32>
+ %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32>
+ %6 = vector.insert %i5, %5[6] : i32 into vector<8xi32>
+ %idx = vector.insert %i3, %6[7] : i32 into vector<8xi32>
+
+ // Set up value vector.
+ %f0 = constant 0.0: f32
+ %f1 = constant 1.0: f32
+ %f2 = constant 2.0: f32
+ %f3 = constant 3.0: f32
+ %f4 = constant 4.0: f32
+ %f5 = constant 5.0: f32
+ %f6 = constant 6.0: f32
+ %f7 = constant 7.0: f32
+ %7 = vector.broadcast %f0 : f32 to vector<8xf32>
+ %8 = vector.insert %f1, %7[1] : f32 into vector<8xf32>
+ %9 = vector.insert %f2, %8[2] : f32 into vector<8xf32>
+ %10 = vector.insert %f3, %9[3] : f32 into vector<8xf32>
+ %11 = vector.insert %f4, %10[4] : f32 into vector<8xf32>
+ %12 = vector.insert %f5, %11[5] : f32 into vector<8xf32>
+ %13 = vector.insert %f6, %12[6] : f32 into vector<8xf32>
+ %val = vector.insert %f7, %13[7] : f32 into vector<8xf32>
+
+ // Set up masks.
+ %t = constant 1: i1
+ %none = vector.constant_mask [0] : vector<8xi1>
+ %some = vector.constant_mask [4] : vector<8xi1>
+ %more = vector.insert %t, %some[7] : i1 into vector<8xi1>
+ %all = vector.constant_mask [8] : vector<8xi1>
+
+ //
+ // Scatter tests.
+ //
+
+ vector.print %idx : vector<8xi32>
+ // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
+
+ call @printmem(%A) : (memref<?xf32>) -> ()
+ // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+
+ call @scatter8(%A, %idx, %none, %val)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+ call @printmem(%A) : (memref<?xf32>) -> ()
+ // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+
+ call @scatter8(%A, %idx, %some, %val)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+ call @printmem(%A) : (memref<?xf32>) -> ()
+ // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
+
+ call @scatter8(%A, %idx, %more, %val)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+ call @printmem(%A) : (memref<?xf32>) -> ()
+ // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
+
+ call @scatter8(%A, %idx, %all, %val)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+ call @printmem(%A) : (memref<?xf32>) -> ()
+ // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
+
+ return
+}
using namespace mlir;
using namespace mlir::vector;
-template <typename T>
-static LLVM::LLVMType getPtrToElementType(T containerType,
- LLVMTypeConverter &typeConverter) {
- return typeConverter.convertType(containerType.getElementType())
- .template cast<LLVM::LLVMType>()
- .getPointerTo();
-}
-
// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
return res;
}
-template <typename TransferOp>
-LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
- TransferOp xferOp, unsigned &align) {
+// Helper that returns the data layout alignment of the memref associated
+// with an op.
+template <typename T>
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
+ unsigned &align) {
Type elementTy =
- typeConverter.convertType(xferOp.getMemRefType().getElementType());
+ typeConverter.convertType(op.getMemRefType().getElementType());
if (!elementTy)
return failure();
return success();
}
+// Helper that returns a vector of pointers, given a memref base and an
+// index vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter, Location loc,
+ Value memref, Value indices, MemRefType memRefType,
+ VectorType vType, Type iType, Value &ptrs) {
+ // Inspect stride and offset structure.
+ //
+ // TODO: flat memory only for now, generalize
+ //
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(memRefType, strides, offset);
+ if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
+ offset != 0 || memRefType.getMemorySpace() != 0)
+ return failure();
+
+ // Base pointer.
+ MemRefDescriptor memRefDescriptor(memref);
+ Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+
+ // Create a vector of pointers from base and indices.
+ //
+  // TODO: this step unfortunately serializes the address computations;
+  //       ideally we would emit splat(base) + index_vector in SIMD form,
+  //       but this does not match well with the current constraints of
+  //       the standard and vector dialects.
+ //
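+  // As a rough sketch, for each lane i this emits the equivalent of
+  //   %off  = llvm.extractelement %indices[i]
+  //   %ptr  = llvm.getelementptr %base[%off]
+  //   %ptrs = llvm.insertelement %ptr, %ptrs[i]
+  // starting from an llvm.mlir.undef vector of pointers.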
+ int64_t size = vType.getDimSize(0);
+ auto pType = memRefDescriptor.getElementType();
+ auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+ auto idxType = typeConverter.convertType(iType);
+ ptrs = rewriter.create<LLVM::UndefOp>(loc, ptrsType);
+ for (int64_t i = 0; i < size; i++) {
+ Value off =
+ extractOne(rewriter, typeConverter, loc, indices, idxType, 1, i);
+ Value ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base, off);
+ ptrs = insertOne(rewriter, typeConverter, loc, ptrs, ptr, ptrsType, 1, i);
+ }
+ return success();
+}
+
static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferReadOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
unsigned align;
- if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
return success();
return failure();
unsigned align;
- if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
TransferWriteOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
unsigned align;
- if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
TransferWriteOp xferOp, ArrayRef<Value> operands,
Value dataPtr, Value mask) {
unsigned align;
- if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
}
};
+/// Conversion pattern for a vector.gather.
+class VectorGatherOpConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorGatherOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto gather = cast<vector::GatherOp>(op);
+ auto adaptor = vector::GatherOpAdaptor(operands);
+
+ // Resolve alignment.
+ unsigned align;
+ if (failed(getMemRefAlignment(typeConverter, gather, align)))
+ return failure();
+
+ // Get index ptrs.
+ VectorType vType = gather.getResultVectorType();
+ Type iType = gather.getIndicesVectorType().getElementType();
+ Value ptrs;
+ if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
+ adaptor.indices(), gather.getMemRefType(), vType,
+ iType, ptrs)))
+ return failure();
+
+ // Replace with the gather intrinsic.
+ ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
+ : adaptor.pass_thru();
+ rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+ gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
+ rewriter.getI32IntegerAttr(align));
+ return success();
+ }
+};
+
+/// Conversion pattern for a vector.scatter.
+class VectorScatterOpConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorScatterOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto scatter = cast<vector::ScatterOp>(op);
+ auto adaptor = vector::ScatterOpAdaptor(operands);
+
+ // Resolve alignment.
+ unsigned align;
+ if (failed(getMemRefAlignment(typeConverter, scatter, align)))
+ return failure();
+
+ // Get index ptrs.
+ VectorType vType = scatter.getValueVectorType();
+ Type iType = scatter.getIndicesVectorType().getElementType();
+ Value ptrs;
+ if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
+ adaptor.indices(), scatter.getMemRefType(), vType,
+ iType, ptrs)))
+ return failure();
+
+ // Replace with the scatter intrinsic.
+ rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
+ scatter, adaptor.value(), ptrs, adaptor.mask(),
+ rewriter.getI32IntegerAttr(align));
+ return success();
+ }
+};
+
+/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorReductionOpConversion(MLIRContext *context,
VectorPrintOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>,
- VectorTypeCastOpConversion>(ctx, converter);
+ VectorTypeCastOpConversion,
+ VectorGatherOpConversion,
+ VectorScatterOpConversion>(ctx, converter);
// clang-format on
}
}
//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(GatherOp op) {
+ VectorType indicesVType = op.getIndicesVectorType();
+ VectorType maskVType = op.getMaskVectorType();
+ VectorType resVType = op.getResultVectorType();
+
+ if (resVType.getElementType() != op.getMemRefType().getElementType())
+ return op.emitOpError("base and result element type should match");
+
+ if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
+ return op.emitOpError("expected result dim to match indices dim");
+ if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+ return op.emitOpError("expected result dim to match mask dim");
+ if (llvm::size(op.pass_thru()) != 0) {
+ VectorType passVType = op.getPassThruVectorType();
+ if (resVType != passVType)
+ return op.emitOpError("expected pass_thru of same type as result type");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ScatterOp op) {
+ VectorType indicesVType = op.getIndicesVectorType();
+ VectorType maskVType = op.getMaskVectorType();
+ VectorType valueVType = op.getValueVectorType();
+
+ if (valueVType.getElementType() != op.getMemRefType().getElementType())
+ return op.emitOpError("base and value element type should match");
+
+ if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
+ return op.emitOpError("expected value dim to match indices dim");
+ if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+ return op.emitOpError("expected value dim to match mask dim");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
// CHECK-SAME: !llvm<"<16 x float>"> into !llvm<"<16 x float>">
// CHECK: llvm.return %[[T]] : !llvm<"<16 x float>">
+
+func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+ %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+ return %0 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: llvm.return %[[G]] : !llvm<"<3 x float>">
+
+func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+ vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref<?xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_op
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
+// CHECK: llvm.return
// expected-error@+1 {{expects operand to be a memref with no layout}}
%0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
}
+
+// -----
+
+func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+ // expected-error@+1 {{'vector.gather' op base and result element type should match}}
+ %0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+ // expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
+ %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32>
+}
+
+// -----
+
+func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>) {
+ // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
+ %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<17xi32>, vector<16xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>) {
+ // expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}}
+ %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<17xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+ // expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}}
+ %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32>
+}
+
+// -----
+
+func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+ // expected-error@+1 {{'vector.scatter' op base and value element type should match}}
+ vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf64>
+}
+
+// -----
+
+func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) {
+ // expected-error@+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}}
+ vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref<?xf32>
+}
+
+// -----
+
+func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+ // expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}}
+ vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+}
+
+// -----
+
+func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+ // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
+ vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
+}
// CHECK: return %[[X]] : vector<16xi32>
return %0 : vector<16xi32>
}
+
+// CHECK-LABEL: @gather_and_scatter
+func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+ %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+ // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ %1 = vector.gather %base, %indices, %mask, %0 : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+ vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+ return
+}
llvm.return
}
+// CHECK-LABEL: @masked_gather_scatter_intrinsics
+llvm.func @masked_gather_scatter_intrinsics(%M: !llvm<"<7 x float*>">, %mask: !llvm<"<7 x i1>">) {
+ // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
+ %a = llvm.intr.masked.gather %M, %mask { alignment = 1: i32} :
+ (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">) -> !llvm<"<7 x float>">
+ // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+ %b = llvm.intr.masked.gather %M, %mask, %a { alignment = 1: i32} :
+ (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">, !llvm<"<7 x float>">) -> !llvm<"<7 x float>">
+ // CHECK: call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %{{.*}}, <7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}})
+ llvm.intr.masked.scatter %b, %M, %mask { alignment = 1: i32} :
+ !llvm<"<7 x float>">, !llvm<"<7 x i1>"> into !llvm<"<7 x float*>">
+ llvm.return
+}
+
// CHECK-LABEL: @memcpy_test
llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) {
// CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})