/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
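+/// Two layouts are supported: linalg.conv_2d_nhwc_hwcf is rewritten to
+/// linalg.conv_1d_nwc_wcf, and linalg.conv_2d_nchw_fchw is rewritten to
+/// linalg.conv_1d_ncw_fcw.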
+template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
- : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+ : public OpRewritePattern<Conv2DOp> {
DownscaleSizeOneWindowed2DConvolution(
MLIRContext *context,
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
- filter(std::move(f)) {}
+ : OpRewritePattern<Conv2DOp>(context, benefit), filter(std::move(f)) {}
- FailureOr<Conv1DNwcWcfOp>
- returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
- PatternRewriter &rewriter) const;
+ FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+ PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ LogicalResult matchAndRewrite(Conv2DOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- FailureOr<LinalgOp> windowed =
- tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
- if (succeeded(windowed)) {
- results.push_back(*windowed);
+ FailureOr<LinalgOp> windowedNhwc =
+ tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+ Conv1DNwcWcfOp>>(target);
+ if (succeeded(windowedNhwc)) {
+ results.push_back(*windowedNhwc);
+ return DiagnosedSilenceableFailure(success());
+ }
+ FailureOr<LinalgOp> windowedNchw =
+ tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+ Conv1DNcwFcwOp>>(target);
+ if (succeeded(windowedNchw)) {
+ results.push_back(*windowedNchw);
return DiagnosedSilenceableFailure(success());
}
FailureOr<LinalgOp> depthwise =
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
-FailureOr<Conv1DNwcWcfOp>
-DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
- linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+template <typename Conv2DOp, typename Conv1DOp>
+FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
+ returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, convOp)))
return failure();
if (convOp.hasBufferSemantics())
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
+ // Get the kernel and output h/w domain indices based on the conv2d layout.
+ int khIndex, kwIndex, ohIndex, owIndex;
+
+ TypeSwitch<Operation *>(convOp)
+ .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+ khIndex = 0;
+ kwIndex = 1;
+ ohIndex = 1;
+ owIndex = 2;
+ })
+ .Case([&](linalg::Conv2DNchwFchwOp op) {
+ khIndex = 2;
+ kwIndex = 3;
+ ohIndex = 2;
+ owIndex = 3;
+ })
+ .Default([&](Operation *op) {
+ llvm_unreachable("unexpected conv2d operation.");
+ });
+
// Only handle the case where at least one of the window dimensions is
// of size 1. Other cases can rely on tiling to reduce to such cases.
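+ // For example, a conv_2d_nhwc_hwcf with a <1 x kw x c x f> kernel and an
+ // <n x 1 x ow x f> output collapses the unit H dimension and becomes a
+ // conv_1d_nwc_wcf.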
- int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
- int64_t ohSize = outputShape[1], owSize = outputShape[2];
+ int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
+ int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
bool removeH = (khSize == 1 && ohSize == 1);
bool removeW = (kwSize == 1 && owSize == 1);
if (!removeH && !removeW)
// dimension.
using RTTBuilder = RankedTensorType::Builder;
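+ // The input and output tensors share the same layout, so the output h/w
+ // indices also locate the spatial dimensions of the input.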
RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+ RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+ RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+ RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
// Rank-reduce operands.
Location loc = convOp.getLoc();
// Rank-reduce strides and dilations too.
// TODO: dropDim 1-liner helper.
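+ // Note: the strides and dilations attributes are ordered [h, w] for both
+ // layouts, so the element to erase does not depend on the layout.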
- auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
+ auto strides =
+ llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
strides.erase(strides.begin() + (removeH ? 0 : 1));
auto stridesAttr = rewriter.getI64VectorAttr(strides);
auto dilations =
- llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
+ llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
- auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+ auto conv1DOp = rewriter.create<Conv1DOp>(
loc, newOutputType, ValueRange{newInput, newKernel},
ValueRange{newOutput}, stridesAttr, dilationsAttr);
void linalg::populateDecomposeConvolutionPatterns(
RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
PatternBenefit benefit) {
- patterns.add<DownscaleSizeOneWindowed2DConvolution,
+ patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+ Conv1DNwcWcfOp>,
+ DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+ Conv1DNcwFcwOp>,
DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
benefit);
}
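
// Usage sketch: assuming a func::FuncOp `funcOp` and the default filter and
// benefit arguments, the decomposition patterns can be driven by the greedy
// rewriter, e.g. from a pass's runOnOperation():
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();
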
return %0 : tensor<?x1x?x?xf32>
}
+// CHECK-LABEL: @conv_2d_nchw_fchw
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
+ outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x?x1x?xf32>
+}
+
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>