return conv1DOp;
}
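+
+/// Rewrites a linalg.conv_2d whose window is of size 1 along one dimension
+/// (together with the corresponding output dimension) into a linalg.conv_1d
+/// on rank-reduced operands.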
+FailureOr<Conv1DOp>
+DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+ PatternRewriter &rewriter) const {
+ if (convOp.hasBufferSemantics())
+ return failure(); // To be implemented.
+
+ Value input = convOp.getInputs().front();
+ Value kernel = convOp.getInputs().back();
+ Value output = convOp.getOutputs().front();
+
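+ // linalg.conv_2d has no batch or channel dimensions: the input is (H, W),
+ // the kernel is (KH, KW), and the output is (OH, OW).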
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+ auto kernelShape = kernelType.getShape();
+ auto outputShape = outputType.getShape();
+
+ // Only handle the case where at least one of the window dimensions is
+ // of size 1. Other cases can rely on tiling to reduce to such cases.
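+ // For example, a <1xKW> kernel producing a <1xOW> output collapses to a
+ // 1-D convolution along W.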
+ int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+ int64_t ohSize = outputShape[0], owSize = outputShape[1];
+ bool removeH = (khSize == 1 && ohSize == 1);
+ bool removeW = (kwSize == 1 && owSize == 1);
+ if (!removeH && !removeW)
+ return failure();
+
+ // Get new shapes and types for all operands by removing the size-1
+ // dimension.
+ using RTTBuilder = RankedTensorType::Builder;
+ RankedTensorType newInputType =
+ RTTBuilder(inputType).dropDim(removeH ? 0 : 1);
+ RankedTensorType newKernelType =
+ RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
+ RankedTensorType newOutputType =
+ RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
+
+ // Rank-reduce operands.
+ Location loc = convOp.getLoc();
+ Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, input, newInputType);
+ Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, kernel, newKernelType);
+ Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, output, newOutputType);
+
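+ // Create the 1-D convolution on the rank-reduced operands.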
+ auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
+ ValueRange{newInput, newKernel},
+ ValueRange{newOutput});
+
+ // Insert the rank-reduced result back into the original output tensor.
+ Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+ rewriter, loc, conv1DOp.getResult(0), output);
+ rewriter.replaceOp(convOp, inserted);
+
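+ // Return the new op so that callers can chain further rewrites on it.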
+ return conv1DOp;
+}
+
void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
Conv1DNwcWcfOp>,
DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
Conv1DNcwFcwOp>,
- DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
- benefit);
+ DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
+ patterns.getContext(), benefit);
patterns.add<
DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
return %0: tensor<1x1x56x96xf32>
}
+// CHECK-LABEL: @conv_2d
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
+func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.conv_2d
+ ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<1x?xf32>
+}
+
// CHECK-LABEL: @pooling_nhwc_sum
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>