From 766ce87e9bed89bc3b5c2c904f1eb2d10be0d3be Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Sun, 5 Jan 2020 19:37:56 -0500
Subject: [PATCH] [mlir][Linalg] Lower linalg.reshape to LLVM for the static
 case

Summary:
This diff adds lowering of the linalg.reshape op to LLVM.

A new descriptor is created with fields initialized as follows:
  1. allocatedPtr, alignedPtr and offset are copied from the source
     descriptor
  2. sizes are copied from the static destination shape
  3. strides are copied from the static strides collected with
     `getStridesAndOffset`

Only the static case, in which the target view conforms to strided memref
semantics, is supported. Other cases are left for future work and will be
added on a per-need basis.

Reviewers: ftynse, mravishankar

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst,
arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72316
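To make steps 2 and 3 concrete, here is a minimal standalone sketch (not
part of the patch; the helper name rowMajorStrides is illustrative) of how
the static strides of an identity-layout memref, which is what
getStridesAndOffset returns for the types used in the test, arise as suffix
products of the shape:

  #include <cstdint>
  #include <vector>

  // Illustrative sketch only: the static strides of a contiguous,
  // row-major (identity-layout) memref are the suffix products of its
  // shape; the offset is 0.
  std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
    std::vector<int64_t> strides(shape.size());
    int64_t running = 1;
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
      strides[i] = running; // innermost dimension is contiguous
      running *= shape[i];  // accumulate the suffix product
    }
    return strides;
  }

  // rowMajorStrides({1, 3, 4, 1, 5}) == {60, 20, 5, 5, 1}
  // rowMajorStrides({3, 4, 5})       == {20, 5, 1}

These are exactly the constants checked in the lowering test below.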
---
 mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 52 +++++++++++++++++++-
 mlir/test/Dialect/Linalg/llvm.mlir                | 60 ++++++++++++++++++++++
 2 files changed, 110 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 2dd36c9..86890b1 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -122,8 +122,14 @@ public:
   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
   Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
   void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+  void setConstantSize(unsigned i, int64_t v) {
+    d.setConstantSize(rewriter(), loc(), i, v);
+  }
   Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
   void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
+  void setConstantStride(unsigned i, int64_t v) {
+    d.setConstantStride(rewriter(), loc(), i, v);
+  }
 
   operator Value() { return d; }
 
@@ -161,6 +167,48 @@ public:
   }
 };
 
+// ReshapeOp creates a new view descriptor of the proper rank.
+// For now, the only conversion supported is for a target MemRef with static
+// sizes and strides.
+class ReshapeOpConversion : public LLVMOpLowering {
+public:
+  explicit ReshapeOpConversion(MLIRContext *context,
+                               LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reshapeOp = cast<ReshapeOp>(op);
+    MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+
+    if (!dstType.hasStaticShape())
+      return matchFailure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto res = getStridesAndOffset(dstType, strides, offset);
+    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+          return ShapedType::isDynamicStrideOrOffset(val);
+        }))
+      return matchFailure();
+
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    ReshapeOpOperandAdaptor adaptor(operands);
+    BaseViewConversionHelper baseDesc(adaptor.view());
+    BaseViewConversionHelper desc(lowering.convertType(dstType));
+    desc.setAllocatedPtr(baseDesc.allocatedPtr());
+    desc.setAlignedPtr(baseDesc.alignedPtr());
+    desc.setOffset(baseDesc.offset());
+    for (auto en : llvm::enumerate(dstType.getShape()))
+      desc.setConstantSize(en.index(), en.value());
+    for (auto en : llvm::enumerate(strides))
+      desc.setConstantStride(en.index(), en.value());
+    rewriter.replaceOp(op, {desc});
+    return matchSuccess();
+  }
+};
+
 /// Conversion pattern that transforms a linalg.slice op into:
 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
@@ -508,8 +556,8 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
 void mlir::populateLinalgToLLVMConversionPatterns(
     LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
     MLIRContext *ctx) {
-  patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
-                  YieldOpConversion>(ctx, converter);
+  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
+                  TransposeOpConversion, YieldOpConversion>(ctx, converter);
 }
 
 namespace {
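Illustrative aside (not part of the patch): the extractvalue/insertvalue
indices exercised by the test below address the standard MLIR memref
descriptor. Modeled as a plain C++ struct for a rank-5 memref of f32 (the
struct and field names are illustrative; the layout matches the
!llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> type in the CHECK
lines):

  #include <cstdint>

  // Layout of !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">.
  struct MemRefDescriptorRank5 {
    float *allocatedPtr;  // index [0]: copied from the source descriptor
    float *alignedPtr;    // index [1]: copied from the source descriptor
    int64_t offset;       // index [2]: copied from the source descriptor
    int64_t sizes[5];     // indices [3, 0..4]: static destination shape
    int64_t strides[5];   // indices [4, 0..4]: static destination strides
  };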
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 7054a3d..d70ee8c 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -196,3 +196,63 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
 // CHECK-LABEL: func @matmul_vec_indexed(
 // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+
+func @reshape_static(%arg0: memref<3x4x5xf32>) {
+  // Reshapes that expand and then collapse back a contiguous memref with some unit dimensions.
+  %0 = linalg.reshape %arg0 [(i, j, k, l, m) -> (i, j),
+                             (i, j, k, l, m) -> (k),
+                             (i, j, k, l, m) -> (l, m)] :
+    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  %r0 = linalg.reshape %0 [(i, j, k, l, m) -> (i, j),
+                           (i, j, k, l, m) -> (k),
+                           (i, j, k, l, m) -> (l, m)] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return
+}
+// CHECK-LABEL: func @reshape_static(
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-- 
2.7.4
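Note: the stride constants checked above are exactly the row-major suffix
products of each static shape, matching the illustrative rowMajorStrides
sketch in the summary: {60, 20, 5, 5, 1} for memref<1x3x4x1x5xf32>
(innermost stride 1, then 1*5 = 5, 5*1 = 5, 5*4 = 20, 20*3 = 60) and
{20, 5, 1} for memref<3x4x5xf32>. Unit dimensions simply inherit the
running suffix product, which is why index [4, 3] carries stride 5 even
though the size at [3, 3] is 1.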