--- /dev/null
+//===- TosaDecomposeTransposeConv.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Insert reshape to binary op's input if needed to match rank
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR//TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+template <typename T>
+static void getValuesFromIntArrayAttribute(ArrayAttr attr,
+ SmallVector<T> &arrayValues) {
+ for (Attribute val : attr.getValue()) {
+ arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+ }
+}
+
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
+ Args &&...args) {
+ auto op = rewriter.create<TosaOp>(loc, result_ty, args...);
+
+ InferShapedTypeOpInterface shapeInterface =
+ dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+ if (!shapeInterface)
+ return op;
+
+ SmallVector<ShapedTypeComponents> returnedShapes;
+ if (shapeInterface
+ .inferReturnTypeComponents(op.getContext(), op.getLoc(),
+ op->getOperands(), op->getAttrDictionary(),
+ op->getRegions(), returnedShapes)
+ .failed())
+ return op;
+
+ // We need to use the element type of the existing result type to generate
+ // the new result shaped type. This is because rescale can include a cast to
+ // different bit-width types and does not have a TypeAttr to define the
+ // target type.
+ auto result = op->getResult(0);
+ auto predictedShape = returnedShapes[0];
+ auto currentKnowledge =
+ mlir::tosa::ValueKnowledge::getKnowledgeFromType(result_ty);
+
+ // Compute the knowledge based on the inferred type.
+ auto inferredKnowledge =
+ mlir::tosa::ValueKnowledge::getPessimisticValueState();
+ inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
+ inferredKnowledge.hasRank = predictedShape.hasRank();
+ if (predictedShape.hasRank()) {
+ for (auto dim : predictedShape.getDims()) {
+ inferredKnowledge.sizes.push_back(dim);
+ }
+ }
+
+ // Compute the new type based on the joined version.
+ auto newKnowledge =
+ mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+ auto new_ty = newKnowledge.getType();
+ result.setType(new_ty);
+ return op;
+}
+
+class TransposeConvDilatedConverter
+ : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+ using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op->getLoc();
+ Value input = op->getOperand(0);
+ Value weight = op->getOperand(1);
+ Value bias = op->getOperand(2);
+
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType weightTy = weight.getType().cast<ShapedType>();
+ ShapedType biasTy = bias.getType().cast<ShapedType>();
+ ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+ llvm::SmallVector<int64_t> pad;
+ llvm::SmallVector<int64_t> stride;
+ llvm::SmallVector<int64_t> dilation;
+
+ getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+ getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+ getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+ // If striding is all 1 we can modify padding and reverse the kernel along
+ // the x/y direction to make it a regular convolution. This is much simpler
+ // then handling striding....
+ if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
+ return failure();
+
+ if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+ !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ return failure();
+
+ int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
+ int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
+ int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
+ int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+
+ llvm::SmallVector<int64_t> convPad(4, 0);
+ convPad[0] = kernelHeight - 1 - pad[0];
+ convPad[2] = kernelWidth - 1 - pad[1];
+ convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
+ convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
+
+ auto reverse1 = rewriter.create<tosa::ReverseOp>(
+ loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+ auto reverse2 = rewriter.create<tosa::ReverseOp>(
+ loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+ Value conv2d;
+ if (op.quantization_info().hasValue()) {
+ conv2d = rewriter.create<tosa::Conv2DOp>(
+ loc, resultTy, input, reverse2, bias,
+ rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+ rewriter.getI64ArrayAttr(dilation),
+ op.quantization_info().getValue());
+ } else {
+ conv2d = rewriter.create<tosa::Conv2DOp>(
+ loc, resultTy, input, reverse2, bias,
+ rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+ rewriter.getI64ArrayAttr(dilation));
+ }
+
+ rewriter.replaceOp(op, conv2d);
+ return success();
+ }
+};
+
+class TransposeConvStridedConverter
+ : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+ using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op->getLoc();
+ Value input = op->getOperand(0);
+ Value weight = op->getOperand(1);
+ Value bias = op->getOperand(2);
+
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType weightTy = weight.getType().cast<ShapedType>();
+ ShapedType biasTy = bias.getType().cast<ShapedType>();
+ ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+ Type inputETy = inputTy.getElementType();
+ Type weightETy = weightTy.getElementType();
+ Type biasETy = biasTy.getElementType();
+ Type resultETy = resultTy.getElementType();
+
+ llvm::SmallVector<int64_t> pad;
+ llvm::SmallVector<int64_t> stride;
+ llvm::SmallVector<int64_t> dilation;
+
+ getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+ getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+ getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+ // If striding is all 1 we can modify padding and reverse the kernel along
+ // the x/y direction to make it a regular convolution. This is much simpler
+ // then handling striding....
+ if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
+ return failure();
+
+ // If strides are all 1 we dont need to use this one.
+ if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
+ return failure();
+
+ if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+ !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ return failure();
+
+ int64_t batch = inputTy.getDimSize(0);
+
+ int64_t outputChannels = weightTy.getDimSize(0);
+ int64_t weightHeight = weightTy.getDimSize(1);
+ int64_t weightWidth = weightTy.getDimSize(2);
+ int64_t inputChannels = weightTy.getDimSize(3);
+
+ // Pad the weight so that it is modulo of the striding.
+ llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+ weightPadding[3] =
+ weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
+ weightPadding[5] =
+ weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+ DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
+ RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+ Value weightPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+ rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+
+ if (op.quantization_info().hasValue()) {
+ auto quantInfo = op.quantization_info().getValue();
+ weight = CreateOpAndInfer<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ weightPaddingVal, nullptr,
+ PadOpQuantizationAttr::get(quantInfo.weight_zp(),
+ rewriter.getContext()));
+
+ } else {
+ weight = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
+ UnrankedTensorType::get(weightETy),
+ weight, weightPaddingVal);
+ }
+
+ weightTy = weight.getType().cast<ShapedType>();
+ weightHeight = weightTy.getDimSize(1);
+ weightWidth = weightTy.getDimSize(2);
+
+ // Split out the width / height by the stride dimensions.
+ llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
+ outputChannels, weightHeight / stride[0],
+ stride[0], weightWidth / stride[1],
+ stride[1], inputChannels};
+ weight = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64ArrayAttr(weightReshapeDims0));
+
+ // Transpose the factored-out stride to the output channels.
+ Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
+ loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+ rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
+
+ weight = CreateOpAndInfer<tosa::TransposeOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ transposeWeightVal);
+
+ // Collapse the strides and output channels into a single dimension.
+ llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+ outputChannels * stride[0] * stride[1], weightHeight / stride[0],
+ weightWidth / stride[1], inputChannels};
+ weight = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64ArrayAttr(weightReshapeDims1));
+ ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
+
+ weight = CreateOpAndInfer<tosa::ReverseOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64IntegerAttr(1));
+ weight = CreateOpAndInfer<tosa::ReverseOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64IntegerAttr(2));
+
+ // We need to pad the input far enough that we can pull all values.
+ llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+ inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
+ inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
+ inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
+ inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
+
+ DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
+ RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+
+ Value inputPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+ rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+
+ if (op.quantization_info().hasValue()) {
+ auto quantInfo = op.quantization_info().getValue();
+ input = CreateOpAndInfer<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(inputETy), input,
+ inputPaddingVal, nullptr,
+ PadOpQuantizationAttr::get(quantInfo.input_zp(),
+ rewriter.getContext()));
+ } else {
+ input = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
+ UnrankedTensorType::get(inputETy),
+ input, inputPaddingVal);
+ }
+
+ // We use a zero bias as we need to broadcast the bias.
+ auto zeroBias = rewriter.create<tosa::ConstOp>(
+ loc,
+ RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+ biasETy),
+ DenseElementsAttr::get(
+ RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+ biasETy),
+ rewriter.getZeroAttr(biasETy)));
+
+ // Perform the convolution using the zero bias.
+ Value conv2d;
+ if (op.quantization_info().hasValue()) {
+ conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), input,
+ weight, zeroBias,
+ /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+ /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+ /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
+ op.quantization_info().getValue())
+ .getResult();
+ } else {
+ conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), input,
+ weight, zeroBias,
+ /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+ /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+ /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
+ .getResult();
+ }
+
+ // Factor the resulting width / height.
+ ShapedType convTy = conv2d.getType().cast<ShapedType>();
+ Type convETy = convTy.getElementType();
+
+ int64_t convHeight = convTy.getDimSize(1);
+ int64_t convWidth = convTy.getDimSize(2);
+
+ // Factor striding out of the convolution result.
+ llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
+ batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+ conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+ rewriter.getI64ArrayAttr(convReshapeDims0));
+
+ // Transpose the factored-out stride to the output channels.
+ Value transposeConvVal = rewriter.create<tosa::ConstOp>(
+ loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+ rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
+
+ conv2d = CreateOpAndInfer<tosa::TransposeOp>(
+ rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
+ transposeConvVal);
+
+ // Fuse striding behavior back into width / height.
+ llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
+ batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+ conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+ rewriter.getI64ArrayAttr(convReshapeDims1));
+
+ // Slice out the final result.
+ llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
+ llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
+ resultTy.getShape().begin());
+ sliceBegin[1] = pad[0];
+ sliceBegin[2] = pad[1];
+
+ auto slice = CreateOpAndInfer<tosa::SliceOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+ rewriter.getI64ArrayAttr(sliceBegin),
+ rewriter.getI64ArrayAttr(resultTy.getShape()))
+ .getResult();
+
+ auto addBias =
+ CreateOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
+
+ rewriter.replaceOp(op, addBias.getResult());
+
+ return success();
+ }
+};
+
+/// Pass that enables broadcast by making all input arrays have the same
+/// number of dimensions. Insert RESHAPE operations to lower rank operand
+struct TosaDecomposeTransposeConv
+ : public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
+public:
+ void runOnFunction() override {
+ auto func = getFunction();
+ RewritePatternSet patterns(func.getContext());
+ patterns
+ .insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
+ func.getContext());
+ (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+ }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
+ return std::make_unique<TosaDecomposeTransposeConv>();
+}
--- /dev/null
+// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s
+
+// CHECK-LABEL: @transpose_conv2d
+func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+ // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+ // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+ // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]}
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+ %1 = tensor.cast %0 : tensor<2x18x19x5xf32> to tensor<2x?x?x5xf32>
+ return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_quantized
+func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
+ // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+ // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+ // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+ return %0 : tensor<2x18x19x5xi32>
+}
+
+// ----
+
+// CHECK-LABEL: @transpose_conv2d_dilated
+func @transpose_conv2d_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+ // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+ // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+ // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [2, 3], pad = [4, 4, 15, 15], stride = [1, 1]}
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x20x29x5xf32>
+ %1 = tensor.cast %0 : tensor<2x20x29x5xf32> to tensor<2x?x?x5xf32>
+ return %1 : tensor<2x?x?x5xf32>
+}
+
+// ----
+
+// CHECK-LABEL: @transpose_conv2d_strided
+func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+ // Manipulate the weight matrix to handle striding.
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]])
+ // CHECK-DAG: %[[RESW1:.+]] = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+ // CHECK-DAG: %[[TRANS:.+]] = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+ // CHECK-DAG: %[[RESW2:.+]] = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+ // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+ // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+ // Pad out the input matrix to handle the transpose conv.
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]])
+
+ // Manipulate the final shape.
+ // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<30xf32>}
+ // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]}
+ // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+ // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+ // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+ // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+ %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
+ return %1 : tensor<2x?x?x5xf32>
+}
+
+// ----
+
+// CHECK-LABEL: @transpose_conv2d_strided_quantized
+func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
+ // Manipulate the weight matrix to handle striding.
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]]) {quantization_info = {input_zp = 42 : i32}}
+ // CHECK-DAG: %[[RESW1:.+]] = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+ // CHECK-DAG: %[[TRANS:.+]] = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+ // CHECK-DAG: %[[RESW2:.+]] = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+ // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+ // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+ // Pad out the input matrix to handle the transpose conv.
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]]) {quantization_info = {input_zp = -22 : i32}}
+
+ // Manipulate the final shape.
+ // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() {value = dense<0> : tensor<30xi32>}
+ // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+ // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+ // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+ // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+ // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+ return %0 : tensor<2x35x47x5xi32>
+}