From 2823b685804b3419c29f3fd8480f4d1ad4fb5c17 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 18 Oct 2019 13:48:26 -0700 Subject: [PATCH] Implement lowering of VectorTypeCastOp to LLVM A VectorTypeCastOp can only be used to lower between statically sized contiguous memrefs of scalar and matching vector type. The sizes and strides are thus fully static and easy to determine. A relevant test is added. This is a step towards solving tensorflow/mlir#189. PiperOrigin-RevId: 275538981 --- .../StandardToLLVM/ConvertStandardToLLVM.h | 5 + .../StandardToLLVM/ConvertStandardToLLVM.cpp | 35 ++++--- mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 106 ++++++++++++++++++++- .../Conversion/VectorToLLVM/vector-to-llvm.mlir | 21 +++- 4 files changed, 150 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index a5d0d3c..bd21c35 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -82,6 +82,11 @@ public: Value *promoteOneMemRefDescriptor(Location loc, Value *operand, OpBuilder &builder); + static constexpr unsigned kPtrPosInMemRefDescriptor = 0; + static constexpr unsigned kOffsetPosInMemRefDescriptor = 1; + static constexpr unsigned kSizePosInMemRefDescriptor = 2; + static constexpr unsigned kStridePosInMemRefDescriptor = 3; + protected: /// LLVM IR module used to parse/create types. llvm::Module *module; diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 15f61ab9..490b669 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -156,10 +156,10 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; -static unsigned kPtrPosInMemRefDescriptor = 0; -static unsigned kOffsetPosInMemRefDescriptor = 1; -static unsigned kSizePosInMemRefDescriptor = 2; -static unsigned kStridePosInMemRefDescriptor = 3; +constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor; Type LLVMTypeConverter::convertMemRefType(MemRefType type) { int64_t offset; SmallVector strides; @@ -282,7 +282,8 @@ public: Type elementTypePtr) { return builder.create( loc, elementTypePtr, memref, - builder.getIndexArrayAttr(kPtrPosInMemRefDescriptor)); + builder.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); } protected: @@ -763,11 +764,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern { memRefDescriptor = rewriter.create( op->getLoc(), structType, memRefDescriptor, allocated, - rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor)); + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); memRefDescriptor = rewriter.create( op->getLoc(), structType, memRefDescriptor, createIndexConstant(rewriter, op->getLoc(), offset), - rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor)); + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); if (type.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. @@ -798,10 +801,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern { int64_t index = indexedSize.index(); memRefDescriptor = rewriter.create( op->getLoc(), structType, memRefDescriptor, indexedSize.value(), - rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index})); + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); memRefDescriptor = rewriter.create( op->getLoc(), structType, memRefDescriptor, strideValues[index], - rewriter.getI64ArrayAttr({kStridePosInMemRefDescriptor, index})); + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); } // Return the final value of the descriptor. @@ -896,7 +901,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { } auto type = transformed.memref()->getType().cast(); - Type elementPtrType = type.getStructElementType(kPtrPosInMemRefDescriptor); + Type elementPtrType = + type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor); Value *bufferPtr = extractMemRefElementPtr( rewriter, op->getLoc(), transformed.memref(), elementPtrType); Value *casted = rewriter.create( @@ -952,7 +958,8 @@ struct DimOpLowering : public LLVMLegalizationPattern { if (ShapedType::isDynamic(shape[index])) rewriter.replaceOpWithNewOp( op, getIndexType(), transformed.memrefOrTensor(), - rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index})); + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); else // Use constant for static size. rewriter.replaceOp( @@ -1015,7 +1022,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { offset == MemRefType::getDynamicStrideOrOffset() ? rewriter.create( loc, indexTy, memRefDescriptor, - rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor)) + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)) : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { Value *stride; @@ -1028,7 +1036,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { // Use dynamic stride. stride = rewriter.create( loc, indexTy, memRefDescriptor, - rewriter.getIndexArrayAttr({kStridePosInMemRefDescriptor, i})); + rewriter.getIndexArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); } Value *additionalOffset = rewriter.create(loc, indices[i], stride); diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 3c3a18d..765c25a 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -155,10 +155,112 @@ public: } }; +class VectorTypeCastOpConversion : public LLVMOpLowering { +public: + explicit VectorTypeCastOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::VectorTypeCastOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + vector::VectorTypeCastOp castOp = cast(op); + MemRefType sourceMemRefType = + castOp.getOperand()->getType().cast(); + MemRefType targetMemRefType = + castOp.getResult()->getType().cast(); + + // Only static shape casts supported atm. + if (!sourceMemRefType.hasStaticShape() || + !targetMemRefType.hasStaticShape()) + return matchFailure(); + + Value *sourceMemRef = operands[0]; + auto llvmSourceDescriptorTy = + sourceMemRef->getType().dyn_cast(); + if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) + return matchFailure(); + + auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return matchFailure(); + + Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType( + LLVMTypeConverter::kPtrPosInMemRefDescriptor); + Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType( + LLVMTypeConverter::kPtrPosInMemRefDescriptor); + + int64_t offset; + SmallVector strides; + auto successStrides = + getStridesAndOffset(targetMemRefType, strides, offset); + bool isContiguous = (strides.back() == 1); + if (isContiguous) { + auto sizes = targetMemRefType.getShape(); + for (int index = 0, e = strides.size() - 2; index < e; ++index) { + if (strides[index] != strides[index + 1] * sizes[index + 1]) { + isContiguous = false; + break; + } + } + } + // Only contiguous tensors supported atm. + if (failed(successStrides) || !isContiguous) + return matchFailure(); + + auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + + // Create descriptor. + Value *desc = rewriter.create(loc, llvmTargetDescriptorTy); + // Set ptr. + Value *ptr = rewriter.create( + loc, llvmSourceElementTy, sourceMemRef, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + ptr = rewriter.create(loc, llvmTargetElementTy, ptr); + desc = rewriter.create( + op->getLoc(), llvmTargetDescriptorTy, desc, ptr, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + // Fill offset 0. + auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); + auto zero = rewriter.create(loc, int64Ty, attr); + desc = rewriter.create( + op->getLoc(), llvmTargetDescriptorTy, desc, zero, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + // Fill size and stride descriptors in memref. + for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { + int64_t index = indexedSize.index(); + auto sizeAttr = + rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); + auto size = rewriter.create(loc, int64Ty, sizeAttr); + desc = rewriter.create( + op->getLoc(), llvmTargetDescriptorTy, desc, size, + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); + auto strideAttr = + rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); + auto stride = rewriter.create(loc, int64Ty, strideAttr); + desc = rewriter.create( + op->getLoc(), llvmTargetDescriptorTy, desc, stride, + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); + } + + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert( + patterns.insert( converter.getDialect()->getContext(), converter); } @@ -190,5 +292,5 @@ OpPassBase *mlir::createLowerVectorToLLVMPass() { } static PassRegistration - pass("vector-lower-to-llvm-dialect", + pass("convert-vector-to-llvm", "Lower the operations from the vector dialect into the LLVM dialect"); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index fa2345e..aefeca4 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> { %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> @@ -46,4 +46,21 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { // CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> // CHECK: llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>"> -// CHECK: llvm.return %{{.*}} : !llvm.float \ No newline at end of file +// CHECK: llvm.return %{{.*}} : !llvm.float + +func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<1xvector<8x8x8xf32>> { + %0 = vector.type_cast %arg0: memref<8x8x8xf32>, memref<1xvector<8x8x8xf32>> + return %0 : memref<1xvector<8x8x8xf32>> +} +// CHECK-LABEL: vector_type_cast +// CHECK: llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }"> +// CHECK: %[[ptr:.*]] = llvm.extractvalue {{.*}}[0 : index] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: %[[bit:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*"> +// CHECK: llvm.insertvalue %[[bit]], {{.*}}[0 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }"> +// CHECK: llvm.mlir.constant(0 : index +// CHECK: llvm.insertvalue {{.*}}[1 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }"> +// CHECK: llvm.mlir.constant(1 : index +// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }"> +// CHECK: llvm.mlir.constant(1 : index +// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }"> + -- 2.7.4