From 69ddee1d2aadaa0b9ac4549f366d1bf5701a65f0 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Tue, 31 Mar 2020 21:21:33 -0700 Subject: [PATCH] [mlir][Linalg] Introduce linalg.pooling_min/max/sum op. Summary: Performs an N-D pooling operation similarly to the description in the TF documentation: https://www.tensorflow.org/api_docs/python/tf/nn/pool Different from the description, this operation doesn't perform on batch and channel. It only takes tensors of rank `N`. ``` output[x[0], ..., x[N-1]] = REDUCE_{z[0], ..., z[N-1]} input[ x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0], ... x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1] ], ``` The required optional arguments are: - strides: an i64 array specifying the stride (i.e. step) for window loops. - dilations: an i64 array specifying the filter upsampling/input downsampling rate - padding: an i64 array of pairs (low, high) specifying the number of elements to pad along a dimension. If strides or dilations attributes are missing then the default value is one for each of the input dimensions. Similarly, padding values are zero for both low and high in each of the dimensions, if not specified. Differential Revision: https://reviews.llvm.org/D76414 --- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h | 16 +- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 168 ++++++++++++++++++--- .../mlir/Dialect/Utils/StructuredOpsUtils.h | 9 ++ mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 21 ++- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 81 ++++++++-- .../Dialect/Linalg/Transforms/LinalgToLoops.cpp | 71 +++++++++ mlir/test/Dialect/Linalg/invalid.mlir | 11 ++ mlir/test/Dialect/Linalg/loops.mlir | 70 +++++++++ mlir/test/Dialect/Linalg/roundtrip.mlir | 42 ++++++ 9 files changed, 444 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h index 7756a08..77d9d9f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -29,6 +29,9 @@ namespace mlir { namespace linalg { class ConvOp; +class PoolingMaxOp; +class PoolingMinOp; +class PoolingSumOp; /// Returns the name mangled library call name to disambiguate between different /// overloads at the C level. The name mangling scheme is basic and uses MLIR @@ -60,12 +63,13 @@ std::string generateLibraryCallName(Operation *op); SmallVector makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context); -/// Builds the indexing expressions for a ConvOp `op`. Returns the vector of -/// AffineMaps representing: -/// `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]` -SmallVector weightedConvInputIndex(ConvOp op, - ArrayRef xs, - ArrayRef zs); +/// Builds the indexing expressions for a ConvOp/PoolingOp `op`. Returns the +/// vector of AffineMaps representing: +/// `stride[i] * outputDims[i] + dilation[i] * windowDims[i] - pad_low[i]` +template +extern SmallVector +weightedPoolingInputIndex(PoolingOp op, ArrayRef outputDims, + ArrayRef windowDims); /// Returns `maybeMap.get()` if `maybeMap` is set, otherwise returns the /// symbol-less identity map of `rank`. 
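As a quick illustration of the index arithmetic above, here is a minimal standalone C++ sketch of the formula `x[i] * strides[i] - pad_before[i] + dilation_rate[i] * z[i]`. The function name `poolingInputIndex` and the plain-vector interface are illustrative assumptions only; in the patch this computation is expressed as affine expressions by `weightedPoolingInputIndex`.

```
// Standalone sketch (not part of the patch) of the pooling input-index formula.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> poolingInputIndex(const std::vector<int64_t> &x, // output coords
                                       const std::vector<int64_t> &z, // window coords
                                       const std::vector<int64_t> &strides,
                                       const std::vector<int64_t> &dilations,
                                       const std::vector<int64_t> &padBefore) {
  assert(x.size() == z.size());
  std::vector<int64_t> idx(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    idx[i] = x[i] * strides[i] - padBefore[i] + dilations[i] * z[i];
  return idx;
}

int main() {
  // 2-D example: strides = [2, 1], dilations = [1, 1], no padding.
  auto idx = poolingInputIndex({3, 4}, {1, 2}, {2, 1}, {1, 1}, {0, 0});
  assert(idx[0] == 7 && idx[1] == 6); // 3*2 + 1 and 4*1 + 2
  return 0;
}
```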
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index ab53fc3..31b89bc 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -251,7 +251,69 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> { let hasFolder = 1; } -def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { +/// A base class for pooling operation such as conv. The arguments must contain +/// optional arguments `strides`, `dilations` and `padding` with following type: +/// OptionalAttr:$strides +/// OptionalAttr:$dilations +/// OptionalAttr:$padding +/// `stirdes` denotes the step of each window along the dimension. +class PoolingBase_Op props> + : LinalgStructured_Op { + let description = [{ + Performs an N-D pooling operation similarly to the description in the TF + documentation: + https://www.tensorflow.org/api_docs/python/tf/nn/pool + + Different from the description, this operation doesn't perform on batch and + channel. It only takes tensors of rank `N`. + + ``` + output[x[0], ..., x[N-1]] = + REDUCE_{z[0], ..., z[N-1]} + input[ + x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0], + ... + x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1] + ], + ``` + + The required optional arguments are: + - strides: an i64 array specifying the stride (i.e. step) for window + loops. + - dilations: an i64 array specifying the filter upsampling/input + downsampling rate + - padding: an i64 array of pairs (low, high) specifying the number of + elements to pad along a dimension. + + If strides or dilations attributes are missing then the default value is + one for each of the input dimensions. Similarly, padding values are zero + for both low and high in each of the dimensions, if not specified. + }]; + + code commonUtils = libraryCallName # [{ + int64_t getStride(unsigned i) { + assert(i < getNumWindowLoops()); + if (!strides().hasValue()) return 1; + return strides()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getDilation(unsigned i) { + assert(i < getNumWindowLoops()); + if (!dilations().hasValue()) return 1; + return dilations()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getLowPad(unsigned i) { + assert(i < getNumWindowLoops()); + if (!padding().hasValue()) return 0; + return padding().getValue().getValue({i, 0}); + } + }]; +} + +def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { let description = [{ Generic n-D convolution as described in the TF documentation: @@ -282,7 +344,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { OptionalAttr:$dilations, OptionalAttr:$padding); - let extraClassDeclaration = libraryCallName # [{ + let extraClassDeclaration = commonUtils # [{ // TODO(ntv) extend to support more than 1 dimensions and potentially // grouping too. 
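The `getStride`/`getDilation`/`getLowPad` helpers above encode the defaulting rules stated in the description: missing `strides` or `dilations` attributes behave as all-ones, and a missing `padding` attribute behaves as all-zeros. Below is a minimal sketch of that behaviour using plain `std::optional` instead of MLIR's `OptionalAttr` machinery; all names are illustrative.

```
// Sketch of the attribute-defaulting behaviour, assuming plain std::optional.
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

struct PoolingAttrs {
  std::optional<std::vector<int64_t>> strides;
  std::optional<std::vector<int64_t>> dilations;
  // Padding is stored as (low, high) pairs, one per dimension.
  std::optional<std::vector<std::pair<int64_t, int64_t>>> padding;

  int64_t getStride(unsigned i) const { return strides ? (*strides)[i] : 1; }
  int64_t getDilation(unsigned i) const { return dilations ? (*dilations)[i] : 1; }
  int64_t getLowPad(unsigned i) const { return padding ? (*padding)[i].first : 0; }
};

int main() {
  PoolingAttrs attrs;              // no attributes set
  assert(attrs.getStride(0) == 1); // defaults to 1
  assert(attrs.getLowPad(0) == 0); // defaults to 0

  attrs.strides = std::vector<int64_t>{2, 1};
  assert(attrs.getStride(0) == 2 && attrs.getStride(1) == 1);
  return 0;
}
```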
unsigned getNumBatchDimensions() { return 1; } @@ -309,26 +371,6 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { return iters; } - int64_t getStride(unsigned i) { - assert(i < getNumWindowLoops()); - if (!strides().hasValue()) return 1; - return strides()->getValue()[i] - .cast().getValue().getSExtValue(); - } - - int64_t getDilation(unsigned i) { - assert(i < getNumWindowLoops()); - if (!dilations().hasValue()) return 1; - return dilations()->getValue()[i] - .cast().getValue().getSExtValue(); - } - - int64_t getLowPad(unsigned i) { - assert(i < getNumWindowLoops()); - if (!padding().hasValue()) return 0; - return padding().getValue().getValue({i, 0}); - } - // F(z0, ..., zN-1, q, k) * // I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q) // -> O(b, x0, ..., xN-1, k) @@ -358,7 +400,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { // Window reduction dims: sum_{z[0], ..., z[N-1], q} auto zs = makeAffineDimExprs(nWin, idx, context); // Construct the weighedSum expression. - auto ws = weightedConvInputIndex(*this, xs, zs); + auto ws = weightedPoolingInputIndex(*this, xs, zs); return SmallVector{ // filter[z[0], ..., z[N-1], q, k] AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), @@ -378,6 +420,86 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { let hasFolder = 1; } +class SingleInputPoolingBase_Op + : PoolingBase_Op, NOutputs<1>]> { + let description = [{ + A base class for single input pooling function. + + TODO: Figure out a better way to handle window dimensions, i.e., eliminate + the fake memref. + The window dimensions are specified by argument `windowDims`. The i-th + dimension in the shape of `windowDims` denotes the size of the window along + dimension i. For example, if the window size is 2x3, then a memref<2x3> + should be passed to the operation as `windowDims`. + }]; + + let arguments = (ins AnyStridedMemRef:$input, + AnyStridedMemRef:$windowDims, + AnyStridedMemRef:$output, + OptionalAttr:$strides, + OptionalAttr:$dilations, + OptionalAttr:$padding); + + let extraClassDeclaration = commonUtils# [{ + llvm::Optional> referenceIterators() { + // Outer parallel loops are always the number of output dimensions. + unsigned nPar = getOutputShapedType(0).getRank(); + // The window loops has the same number loops with output dimensions. + unsigned nWin = nPar; + SmallVector iters(nPar, getParallelIteratorTypeName()); + iters.reserve(nPar + nWin); + iters.append(nWin, getWindowIteratorTypeName()); + return iters; + } + + llvm::Optional> referenceIndexingMaps() { + MLIRContext *context = getContext(); + auto nPar = getNumParallelLoops(); + auto nWin = getNumWindowLoops(); + assert(nWin > 0 && "expected at least one window dimension"); + unsigned idx = 0; + auto outputDims = makeAffineDimExprs(nPar, idx, context); + auto windowDims = makeAffineDimExprs(nWin, idx, context); + // Construct the weighedSum expression. + auto inputDims = + weightedPoolingInputIndex(*this, outputDims, windowDims); + return SmallVector{ + // input + AffineMap::get(idx, 0, inputDims), + // windowDims + AffineMap::get(idx, 0, windowDims), + // output + AffineMap::get(idx, 0, outputDims) + }; + } + }]; + + let verifier = [{ return ::verify(*this); }]; + + let hasFolder = 1; +} + +def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> { + let description = [{ + Takes max op as pooling operation, i.e., it samples the maximum value in the + window. 
+ }]; +} + +def PoolingMinOp: SingleInputPoolingBase_Op<"pooling_min"> { + let description = [{ + Takes min op as pooling operation, i.e., it samples the minimum value in the + window. + }]; +} + +def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> { + let description = [{ + Takes add op as pooling operation, i.e., it accumulates the values in the + window. + }]; +} + //===----------------------------------------------------------------------===// // Generic Linalg ops. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index d54791a..bb37bb2 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -72,6 +72,15 @@ constexpr StringRef getFunAttrName() { return "fun"; } /// function that implements the structured op. constexpr StringRef getLibraryCallAttrName() { return "library_call"; } +/// Attribute name for the StrArrayAttr which encodes the value of strides. +constexpr StringRef getStridesAttrName() { return "strides"; } + +/// Attribute name for the StrArrayAttr which encodes the value of dilations. +constexpr StringRef getDilationsAttrName() { return "dilations"; } + +/// Attribute name for the StrArrayAttr which encodes the value of paddings. +constexpr StringRef getPaddingAttrName() { return "padding"; } + /// Use to encode that a particular iterator type has parallel semantics. constexpr StringRef getParallelIteratorTypeName() { return "parallel"; } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 577b134..b493aa6 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -524,12 +524,21 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant // attribute values such as kernel striding and dilation. - patterns.insert, - LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, LinalgOpConversion>( - ctx); + // clang-format off + patterns.insert< + CopyTransposeConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion>(ctx); + // clang-format on } } // namespace diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index aa340e5..077b34c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -140,7 +140,6 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs(), attrNames); p << ": " << op.getOperandTypes(); - auto outputTensorTypes = op.getResultTypes(); if (!outputTensorTypes.empty()) p << " -> " << outputTensorTypes; @@ -827,8 +826,10 @@ static LogicalResult verify(CopyOp op) { return success(); } -static LogicalResult -verifyStrideOrDilation(ConvOp op, ArrayRef attrs, bool isStride) { +template +static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, + ArrayRef attrs, + bool isStride) { auto strideOrDilation = isStride ? 
"stride" : "dilation"; if (attrs.size() != op.getNumWindowLoops()) return op.emitOpError("expects num ") @@ -860,6 +861,41 @@ static LogicalResult verify(ConvOp op) { return success(); } +template +LogicalResult verifySingleInputPoolingOp(PoolingOp op) { + auto inputType = op.input().getType().template cast(); + auto outputType = op.output().getType().template cast(); + if (outputType.getElementType() != inputType.getElementType()) + return op.emitOpError("expects memref elemental types to match"); + + auto windowDimsType = op.windowDims().getType().template cast(); + if (outputType.getRank() != inputType.getRank() || + outputType.getRank() != windowDimsType.getRank()) + return op.emitOpError("expects memref ranks to match"); + + if (auto strides = op.strides()) { + if (failed( + verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) + return failure(); + } + if (auto dilations = op.dilations()) { + if (failed(verifyStrideOrDilation(op, dilations->getValue(), + /*isStride=*/false))) + return failure(); + } + return success(); +} + +static LogicalResult verify(PoolingMaxOp op) { + return verifySingleInputPoolingOp(op); +} +static LogicalResult verify(PoolingMinOp op) { + return verifySingleInputPoolingOp(op); +} +static LogicalResult verify(PoolingSumOp op) { + return verifySingleInputPoolingOp(op); +} + namespace mlir { namespace linalg { @@ -894,21 +930,34 @@ mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, return res; } +template SmallVector -mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef xs, - ArrayRef zs) { - assert(xs.size() == zs.size()); +mlir::linalg::weightedPoolingInputIndex(PoolingOp op, + ArrayRef outputDims, + ArrayRef windowDims) { + assert(outputDims.size() == windowDims.size()); SmallVector res; - res.reserve(xs.size()); - for (unsigned i = 0, e = xs.size(); i < e; ++i) { + res.reserve(outputDims.size()); + for (unsigned i = 0, e = outputDims.size(); i < e; ++i) { // TODO(ntv): add a level of indirection to linalg.generic. 
- auto expr = - op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i); + auto expr = op.getStride(i) * outputDims[i] + + op.getDilation(i) * windowDims[i] - op.getLowPad(i); res.push_back(expr); } return res; } +#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE) \ + template SmallVector \ + mlir::linalg::weightedPoolingInputIndex( \ + OP_TYPE op, ArrayRef outputDims, \ + ArrayRef windowDims); + +INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp) +INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp) +INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp) +INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp) + SmallVector mlir::linalg::concat(ArrayRef a, ArrayRef b) { auto rangeA = llvm::make_range(a.begin(), a.end()); @@ -959,6 +1008,18 @@ LogicalResult ConvOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); } +LogicalResult PoolingMaxOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} +LogicalResult PoolingMinOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} +LogicalResult PoolingSumOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} LogicalResult CopyOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index eb2a881..ae589e7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -106,6 +106,23 @@ static void inlineRegionAndEmitStdStore(OpType op, } } +// Returns a pair that contains input indices and output indices of a +// SingleInputPoolingOp `op`. +template +static std::pair, SmallVector> +getInputAndOutputIndices(ArrayRef allIvs, SingleInputPoolingOp op) { + auto &b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + auto mapsRange = op.indexing_maps().template getAsRange(); + auto maps = + functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange); + SmallVector iIdx( + makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); + SmallVector oIdx( + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); + return {iIdx, oIdx}; +} + namespace { template class LinalgScopedEmitter {}; @@ -273,6 +290,57 @@ public: } }; +template +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + PoolingMaxOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + ValueHandleArray iIdx(indices.first); + ValueHandleArray oIdx(indices.second); + + // Emit scalar form. + ValueHandle lhs = std_load(op.output(), oIdx); + ValueHandle rhs = std_load(op.input(), iIdx); + using edsc::op::operator>; + ValueHandle maxValue = std_select(lhs > rhs, lhs, rhs); + std_store(maxValue, op.output(), oIdx); + } +}; + +template +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + PoolingMinOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + ValueHandleArray iIdx(indices.first); + ValueHandleArray oIdx(indices.second); + + // Emit scalar form. 
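Per iteration of the generated loop nest, the max-pooling emitter above produces a load of the current output value, a load of the strided/dilated input value, a compare-select, and a store back. A minimal 1-D C++ reference of that behaviour follows (unit dilation, no padding; names are illustrative, not part of the patch).

```
// Plain C++ reference of the scalar loop body the max-pooling emitter
// generates: O[x] = max(O[x], I[x * stride + z]) over every window offset z.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

void poolingMax1D(const std::vector<float> &input, int64_t windowSize,
                  int64_t stride, std::vector<float> &output) {
  for (int64_t x = 0, e = output.size(); x < e; ++x) { // parallel loop
    for (int64_t z = 0; z < windowSize; ++z) {         // window loop
      float lhs = output[x];                           // std_load(output)
      float rhs = input[x * stride + z];               // std_load(input)
      output[x] = std::max(lhs, rhs);                  // select + store
    }
  }
}

int main() {
  std::vector<float> input = {1, 5, 2, 7, 3, 0};
  // The output must be pre-initialized (e.g. with -inf); the linalg op
  // itself only performs the reduction update.
  std::vector<float> output = {-1e30f, -1e30f};
  poolingMax1D(input, /*windowSize=*/3, /*stride=*/2, output);
  assert(output[0] == 5 && output[1] == 7); // windows {1,5,2} and {2,7,3}
  return 0;
}
```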
+ ValueHandle lhs = std_load(op.output(), oIdx); + ValueHandle rhs = std_load(op.input(), iIdx); + using edsc::op::operator<; + ValueHandle minValue = std_select(lhs < rhs, lhs, rhs); + std_store(minValue, op.output(), oIdx); + } +}; + +template +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + PoolingSumOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + SmallVector iIdx = indices.first; + SmallVector oIdx = indices.second; + IndexedValueType input(op.input()), output(op.output()); + + // Emit scalar form. + output(oIdx) += input(iIdx); + } +}; + // Emits the MLIR for the scalar part of the generic op by: // 1. Emitting std_load and std_store ops for each input and output // view in order. This is achieved by applying the appropriate input or @@ -688,6 +756,9 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp) INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp) INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp) INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp) INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 59e4a76..7a82915 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -513,3 +513,14 @@ func @reshape(%arg0: memref) { %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (k)>] : memref into memref (d0 * s0 + d1)>> } + +// ----- + +func @pooling_rank_mismatch(%arg0: memref, + %arg1: memref<2x3xf32>, + %arg2: memref) { + // expected-error @+1 {{expects memref ranks to match}} + linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}: + memref, memref<2x3xf32>, memref + return +} diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir index c8d114b..1bd0cf6 100644 --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -9,6 +9,7 @@ // CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> // CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)> +// CHECK-DAG: #[[Stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> // CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> // CHECK-DAG: #[[Stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)> @@ -251,6 +252,75 @@ func @conv_padding(%arg0: memref, // CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 // CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +func @pooling_max(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_max +// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref +// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref +// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref +// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} { +// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[IY:.*]] = affine.apply 
#[[Stride1Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = load %{{.*}}[%[[IX]], %[[IY]]] : memref +// CHECK: %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32 +// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref + +func @pooling_min(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_min +// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref +// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref +// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref +// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} { +// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = load %{{.*}}[%[[IX]], %[[IY]]] : memref +// CHECK: %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32 +// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref + +func @pooling_sum(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_sum(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_sum +// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref +// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref +// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref +// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} { +// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[RHS:.*]] = load %{{.*}}[%[[IX]], %[[IY]]] : memref +// CHECK: %[[LHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref +// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32 +// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref + func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) { %f0 = constant 0.0 : f32 return %f0, %f0 : f32, f32 diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 468fad4..05d35f8 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -244,6 +244,48 @@ func @conv_padding(%arg0: memref, // ----- +func @pooling_max(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_max +// CHECK: linalg.pooling_max(%{{.*}}, %{{.*}}, %{{.*}}) +// CHECK-SAME: {strides = [2, 1, 2]} +// CHECK-SAME: memref, memref, memref + +// ----- + +func @pooling_min(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_min(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_min +// CHECK: linalg.pooling_min(%{{.*}}, %{{.*}}, %{{.*}}) +// CHECK-SAME: {strides = [2, 1, 2]} +// CHECK-SAME: memref, memref, memref + +// ----- + +func @pooling_sum(%arg0: memref, + %arg1: memref, + %arg2: memref) { 
+ linalg.pooling_sum(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_sum +// CHECK: linalg.pooling_sum(%{{.*}}, %{{.*}}, %{{.*}}) +// CHECK-SAME: {strides = [2, 1, 2]} +// CHECK-SAME: memref, memref, memref + +// ----- + // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> -- 2.7.4
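Putting the pieces together, the following is a behavioural C++ reference of the 2-D `linalg.pooling_sum` lowering exercised in loops.mlir (strides = [2, 1], unit dilations, no padding). This is only a sketch under those assumptions; the op itself never materializes the window memref's contents, it only uses its shape.

```
// Behavioural reference of 2-D sum pooling with strides [2, 1].
#include <cassert>
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

void poolingSum2D(const Matrix &input, int64_t winH, int64_t winW,
                  int64_t strideH, int64_t strideW, Matrix &output) {
  for (size_t ox = 0; ox < output.size(); ++ox)
    for (size_t oy = 0; oy < output[0].size(); ++oy)
      for (int64_t zx = 0; zx < winH; ++zx)
        for (int64_t zy = 0; zy < winW; ++zy)
          // input index = output index * stride + window index (dilation = 1).
          output[ox][oy] += input[ox * strideH + zx][oy * strideW + zy];
}

int main() {
  Matrix input = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  Matrix output = {{0, 0}}; // 1x2 output, zero-initialized
  poolingSum2D(input, /*winH=*/2, /*winW=*/2, /*strideH=*/2, /*strideW=*/1,
               output);
  // Rows {0,1}, columns {0,1}: 1+2+4+5 = 12; columns {1,2}: 2+3+5+6 = 16.
  assert(output[0][0] == 12 && output[0][1] == 16);
  return 0;
}
```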