[mlir][tosa] Migrate tosa to more efficient linalg.conv
author Rob Suderman <rob.suderman@gmail.com>
Wed, 11 Aug 2021 18:05:08 +0000 (11:05 -0700)
committer Rob Suderman <rob.suderman@gmail.com>
Wed, 11 Aug 2021 18:05:12 +0000 (11:05 -0700)
Existing linalg.conv2d is not well optimized for performance. Changed to a
version that is better aligned for optimization. Include the corresponding
transposes needed to use this optimized version.
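
For example (a sketch assembled from the updated tests in this patch; value
names and shapes are illustrative), a float tosa.conv2d with an OHWI kernel
now lowers to a kernel transpose into HWCF order feeding the new linalg op:

  // Kernel layout permutation: OHWI -> HWCF.
  %perm = constant dense<[1, 2, 3, 0]> : tensor<4xi64>
  %hwcf = "tosa.transpose"(%weights, %perm)
      : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
  // %bias_bcast is the conv bias broadcast to the output shape.
  %conv = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
      ins(%input, %hwcf : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>)
      outs(%bias_bcast : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>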

This also splits the conv and depthwise conv into separate implementations
to avoid overly complex lowerings.
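
The depthwise path instead keeps the HWCM kernel, convolves into a 5-D
(N, H, W, C, M) intermediate, and collapses the trailing channel/multiplier
pair back to NHWC. A sketch with the same illustrative shapes as the tests:

  // Expand the broadcast bias so the channel multiplier gets its own dimension.
  %dbias = linalg.tensor_expand_shape %bias_bcast [[0], [1], [2], [3, 4]]
      : tensor<1x5x5x33xf32> into tensor<1x5x5x3x11xf32>
  %dw = linalg.depthwise_conv2D_nchw
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
      ins(%input, %kernel : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>)
      outs(%dbias : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
  // Collapse (C, M) back into the NHWC channel dimension.
  %out = linalg.tensor_collapse_shape %dw [[0], [1], [2], [3, 4]]
      : tensor<1x5x5x3x11xf32> into tensor<1x5x5x33xf32>

(The depthwise op name retains the misleading "nchw" suffix flagged by the
TODOs removed from core_named_ops.py below.)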

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D107504

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Linalg/named-ops.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 53b54e1bff9fbabb34d0bc1c0cf9126e56374b6b..3e1fcabc8cb9b3cb35eca5c0755e086c2ac0935e 100644
@@ -628,10 +628,10 @@ structured_op: !LinalgStructuredOpConfig
                   scalar_arg: B
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: conv_2d_input_nhwc_filter_ohwi_poly
-  cpp_class_name: Conv2DInputNhwcFilterOhwiPolyOp
+  name: conv_2d_nchw
+  cpp_class_name: Conv2DNchwOp
   doc: |-
-    Performs 2-D convolution.
+    Performs 2-D convolution.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
@@ -648,13 +648,13 @@ structured_op: !LinalgStructuredOpConfig
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-      -> (s4, s5, s6, s3)>
+      -> (s4, s1, s5, s6)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-      -> (s0, s7, s8, s4)>
+      -> (s0, s4, s7, s8, s1)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
@@ -670,18 +670,18 @@ structured_op: !LinalgStructuredOpConfig
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d6)>
+      s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d5, d3, d4, d6)>
+      s9, s10, s11, s12] -> (d1, d4, d5, d6)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d1, d2, d5)>
+      s9, s10, s11, s12] -> (d0, d1, d2, d3)>
   iterator_types:
   - parallel
   - parallel
   - parallel
+  - parallel
   - reduction
   - reduction
-  - parallel
   - reduction
   assignments:
   - !ScalarAssign
@@ -710,14 +710,13 @@ structured_op: !LinalgStructuredOpConfig
                   scalar_arg: K
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: conv_2d_input_nhwc_filter_ohwi_poly_q
-  cpp_class_name: Conv2DInputNhwcFilterOhwiPolyQOp
+  name: conv_2d_nhwc_hwcf
+  cpp_class_name: Conv2DNhwcHwcfOp
   doc: |-
-    Performs a 2-D quantized convolution.
+    Performs 2-D convolution.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output. Includes zero point
-    adjustment for quantization.
+    them to the same data type as the accumulator/output.
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -731,21 +730,13 @@ structured_op: !LinalgStructuredOpConfig
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-      -> (s4, s5, s6, s3)>
-  - !LinalgOperandDefConfig
-    name: IZp
-    usage: InputOperand
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: KZp
-    usage: InputOperand
-    type_var: I32
+      -> (s4, s5, s3, s6)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-      -> (s0, s7, s8, s4)>
+      -> (s0, s7, s8, s6)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
@@ -761,22 +752,18 @@ structured_op: !LinalgStructuredOpConfig
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d5, d3, d4, d6)>
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d4 * s11, d2 * s10 + d5 * s12, d6)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> ()>
+      s9, s10, s11, s12] -> (d4, d5, d6, d3)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d1, d2, d5)>
+      s9, s10, s11, s12] -> (d0, d1, d2, d3)>
   iterator_types:
   - parallel
   - parallel
   - parallel
+  - parallel
   - reduction
   - reduction
-  - parallel
   - reduction
   assignments:
   - !ScalarAssign
@@ -792,37 +779,17 @@ structured_op: !LinalgStructuredOpConfig
             fn_name: mul
             operands:
             - !ScalarExpression
-              scalar_apply:
-                fn_name: sub
+              symbolic_cast:
+                type_var: U
                 operands:
                 - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: I
-                - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: IZp
+                  scalar_arg: I
             - !ScalarExpression
-              scalar_apply:
-                fn_name: sub
+              symbolic_cast:
+                type_var: U
                 operands:
                 - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: K
-                - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: KZp
+                  scalar_arg: K
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
@@ -906,13 +873,14 @@ structured_op: !LinalgStructuredOpConfig
                   scalar_arg: K
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: depthwise_conv_2D_nchw
-  cpp_class_name: DepthwiseConv2DNchwOp
+  name: conv_2d_nhwc_hwcf_q
+  cpp_class_name: Conv2DNhwcHwcfQOp
   doc: |-
-    Performs depth-wise 2-D convolution.
+    Performs 2-D convolution with zero point offsets.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -927,12 +895,20 @@ structured_op: !LinalgStructuredOpConfig
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
       -> (s4, s5, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    usage: InputOperand
+    type_var: I32
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-      -> (s0, s7, s8, s3, s6)>
+      -> (s0, s7, s8, s6)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
@@ -948,19 +924,23 @@ structured_op: !LinalgStructuredOpConfig
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)>
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d4 * s11, d2 * s10 + d5 * s12, d6)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d3, d4, d5, d6)>
+      s9, s10, s11, s12] -> (d4, d5, d6, d3)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)>
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d3)>
   iterator_types:
   - parallel
   - parallel
   - parallel
+  - parallel
+  - reduction
   - reduction
   - reduction
-  - parallel
-  - parallel
   assignments:
   - !ScalarAssign
     arg: O
@@ -975,21 +955,41 @@ structured_op: !LinalgStructuredOpConfig
             fn_name: mul
             operands:
             - !ScalarExpression
-              symbolic_cast:
-                type_var: U
+              scalar_apply:
+                fn_name: sub
                 operands:
                 - !ScalarExpression
-                  scalar_arg: I
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
             - !ScalarExpression
-              symbolic_cast:
-                type_var: U
+              scalar_apply:
+                fn_name: sub
                 operands:
                 - !ScalarExpression
-                  scalar_arg: K
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: depthwise_conv2D_nchw_q
-  cpp_class_name: DepthwiseConv2DNchwQOp
+  name: depthwise_conv2D_nchw
+  cpp_class_name: DepthwiseConv2DNchwOp
   doc: |-
     Performs depth-wise 2-D convolution.
 
@@ -1009,14 +1009,6 @@ structured_op: !LinalgStructuredOpConfig
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
       -> (s4, s5, s3, s6)>
-  - !LinalgOperandDefConfig
-    name: IZp
-    usage: InputOperand
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: KZp
-    usage: InputOperand
-    type_var: I32
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
@@ -1041,10 +1033,6 @@ structured_op: !LinalgStructuredOpConfig
       s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
       s9, s10, s11, s12] -> (d3, d4, d5, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> ()>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
       s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)>
   iterator_types:
@@ -1069,43 +1057,23 @@ structured_op: !LinalgStructuredOpConfig
             fn_name: mul
             operands:
             - !ScalarExpression
-              scalar_apply:
-                fn_name: sub
+              symbolic_cast:
+                type_var: U
                 operands:
                 - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: I
-                - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: IZp
+                  scalar_arg: I
             - !ScalarExpression
-              scalar_apply:
-                fn_name: sub
+              symbolic_cast:
+                type_var: U
                 operands:
                 - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: K
-                - !ScalarExpression
-                  symbolic_cast:
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: KZp
+                  scalar_arg: K
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: conv_2d_nchw
-  cpp_class_name: Conv2DNchwOp
+  name: depthwise_conv2D_nchw_q
+  cpp_class_name: DepthwiseConv2DNchwQOp
   doc: |-
-    Performs 2-D convolution.
+    Performs depth-wise 2-D convolution.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
@@ -1122,13 +1090,21 @@ structured_op: !LinalgStructuredOpConfig
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-      -> (s4, s1, s5, s6)>
+      -> (s4, s5, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    usage: InputOperand
+    type_var: I32
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-      -> (s0, s4, s7, s8, s1)>
+      -> (s0, s7, s8, s3, s6)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
@@ -1144,19 +1120,23 @@ structured_op: !LinalgStructuredOpConfig
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d1, d4, d5, d6)>
+      s9, s10, s11, s12] -> (d3, d4, d5, d6)>
     - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10, s11, s12] -> (d0, d1, d2, d3)>
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)>
   iterator_types:
   - parallel
   - parallel
   - parallel
-  - parallel
-  - reduction
   - reduction
   - reduction
+  - parallel
+  - parallel
   assignments:
   - !ScalarAssign
     arg: O
@@ -1171,17 +1151,37 @@ structured_op: !LinalgStructuredOpConfig
             fn_name: mul
             operands:
             - !ScalarExpression
-              symbolic_cast:
-                type_var: U
+              scalar_apply:
+                fn_name: sub
                 operands:
                 - !ScalarExpression
-                  scalar_arg: I
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
             - !ScalarExpression
-              symbolic_cast:
-                type_var: U
+              scalar_apply:
+                fn_name: sub
                 operands:
                 - !ScalarExpression
-                  scalar_arg: K
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_sum
@@ -1896,3 +1896,4 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
+
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 37687337e10b81cbe24d43b9729f1d6b9815eef0..8e24f03a0dacfa5f691662872e38e9e8c28f9698 100644
@@ -849,104 +849,213 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   return success();
 }
 
-static LogicalResult
-convolutionMatchAndRewriterHelper(Operation *op,
-                                  ConversionPatternRewriter &rewriter) {
-  Location loc = op->getLoc();
-  Value input = op->getOperand(0);
-  Value weight = op->getOperand(1);
-  Value bias = op->getOperand(2);
+namespace {
 
-  ShapedType inputTy = input.getType().cast<ShapedType>();
-  ShapedType weightTy = weight.getType().cast<ShapedType>();
-  ShapedType biasTy = bias.getType().cast<ShapedType>();
-  ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+template <typename SrcOp>
+class PointwiseConverter : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
 
-  Type inputETy = inputTy.getElementType();
-  Type resultETy = resultTy.getElementType();
-
-  auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
-  auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
-  auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
-
-  bool isQuantized = op->hasAttr("quantization_info");
-  IntegerAttr iZp;
-  IntegerAttr kZp;
-  if (isQuantized) {
-    auto quantizationInfo =
-        op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
-    iZp = rewriter.getI32IntegerAttr(
-        quantizationInfo.input_zp().getValue().getSExtValue());
-    kZp = rewriter.getI32IntegerAttr(
-        quantizationInfo.weight_zp().getValue().getSExtValue());
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const final {
+    return elementwiseMatchAndRewriteHelper(op, rewriter);
   }
+};
 
-  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-      !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-    return rewriter.notifyMatchFailure(op,
-                                       "tosa.conv ops require static shapes");
+class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
+public:
+  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
 
-  auto weightShape = weightTy.getShape();
-  auto resultShape = resultTy.getShape();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
 
-  // Apply padding as necessary.
-  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
-  llvm::SmallVector<int64_t> pad;
-  pad.resize(2, 0);
-  getValuesFromIntArrayAttribute(padAttr, pad);
-  pad.resize(pad.size() + 2, 0);
+    Type inputETy = inputTy.getElementType();
+    Type resultETy = resultTy.getElementType();
 
-  input = applyPad(loc, input, pad, zeroAttr, rewriter);
+    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
+    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
+    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
+    bool isQuantized = op->hasAttr("quantization_info");
 
-  // Broadcast the initial value to the output tensor before convolving.
-  SmallVector<AffineMap, 4> indexingMaps;
-  indexingMaps.push_back(AffineMap::get(
-      /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-      {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(op,
+                                         "tosa.conv ops require static shapes");
 
-  Value initTensor = rewriter.create<linalg::InitTensorOp>(
-      loc, resultTy.getShape(), resultTy.getElementType());
+    auto weightShape = weightTy.getShape();
 
-  Value biasBroadcast =
-      rewriter
-          .create<linalg::GenericOp>(
-              loc, resultTy, bias, initTensor, indexingMaps,
-              getNParallelLoopsAttrs(resultTy.getRank()),
-              [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                  ValueRange args) {
-                nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
-              })
-          .getResult(0);
-
-  // Extract the attributes for convolution.
-  llvm::SmallVector<int64_t> stride, dilation;
-  getValuesFromIntArrayAttribute(strideTosaAttr, stride);
-  getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);
-
-  // Create the convolution op.
-  auto strideAttr = DenseIntElementsAttr::get(
-      RankedTensorType::get({2}, rewriter.getI64Type()), stride);
-  auto dilationAttr = DenseIntElementsAttr::get(
-      RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
-
-  if (isa<tosa::Conv2DOp>(op) && !isQuantized) {
-    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyOp>(
+    // Apply padding as necessary.
+    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+    llvm::SmallVector<int64_t> pad;
+    pad.resize(2, 0);
+    getValuesFromIntArrayAttribute(padAttr, pad);
+    pad.resize(pad.size() + 2, 0);
+    input = applyPad(loc, input, pad, zeroAttr, rewriter);
+
+    // Transpose the kernel to match dimension ordering of the linalg
+    // convolution operation.
+    // TODO(suderman): See if this can be efficiently folded - check whether
+    // the input is used anywhere else, if not fold the constant.
+    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
+    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
+                                        weightShape[3], weightShape[0]};
+    auto weightPermAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
+    Value weightPermValue = rewriter.create<ConstantOp>(loc, weightPermAttr);
+    Type newWeightTy =
+        RankedTensorType::get(newWeightShape, weightTy.getElementType());
+    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                weightPermValue);
+
+    // Broadcast the initial value to the output tensor before convolving.
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(AffineMap::get(
+        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
+
+    Value biasBroadcast =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, bias, initTensor, indexingMaps,
+                getNParallelLoopsAttrs(resultTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+                })
+            .getResult(0);
+
+    // Extract the attributes for convolution.
+    llvm::SmallVector<int64_t> stride, dilation;
+    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
+    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);
+
+    // Create the convolution op.
+    auto strideAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
+    auto dilationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
+
+    Value conv;
+    if (isQuantized) {
+      auto quantizationInfo =
+          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+      auto iZp = rewriter.getI32IntegerAttr(
+          quantizationInfo.input_zp().getValue().getSExtValue());
+      auto kZp = rewriter.getI32IntegerAttr(
+          quantizationInfo.weight_zp().getValue().getSExtValue());
+
+      auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
+      auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
+      rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(
+          op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+          ValueRange{biasBroadcast}, strideAttr, dilationAttr);
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(
         op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
         strideAttr, dilationAttr);
     return success();
   }
+};
 
-  if (isa<tosa::Conv2DOp>(op) && isQuantized) {
-    auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
-    auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
-    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyQOp>(
-        op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-        ValueRange{biasBroadcast}, strideAttr, dilationAttr);
-    return success();
-  }
+class DepthwiseConvConverter
+    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
+public:
+  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::DepthwiseConv2DOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
 
-  if (isa<tosa::DepthwiseConv2DOp>(op)) {
+    Type inputETy = inputTy.getElementType();
+    Type resultETy = resultTy.getElementType();
+
+    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
+    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
+    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
+
+    bool isQuantized = op->hasAttr("quantization_info");
+    IntegerAttr iZp;
+    IntegerAttr kZp;
+    if (isQuantized) {
+      auto quantizationInfo =
+          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+      iZp = rewriter.getI32IntegerAttr(
+          quantizationInfo.input_zp().getValue().getSExtValue());
+      kZp = rewriter.getI32IntegerAttr(
+          quantizationInfo.weight_zp().getValue().getSExtValue());
+    }
+
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(op,
+                                         "tosa.conv ops require static shapes");
+
+    auto weightShape = weightTy.getShape();
+    auto resultShape = resultTy.getShape();
+
+    // Apply padding as necessary.
+    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+    llvm::SmallVector<int64_t> pad;
+    pad.resize(2, 0);
+    getValuesFromIntArrayAttribute(padAttr, pad);
+    pad.resize(pad.size() + 2, 0);
+
+    input = applyPad(loc, input, pad, zeroAttr, rewriter);
+
+    // Broadcast the initial value to the output tensor before convolving.
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(AffineMap::get(
+        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+
+    Value initTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, resultShape, resultETy);
+
+    Value biasBroadcast =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, bias, initTensor, indexingMaps,
+                getNParallelLoopsAttrs(resultTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+                })
+            .getResult(0);
+
+    // Extract the attributes for convolution.
+    llvm::SmallVector<int64_t> stride, dilation;
+    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
+    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);
+
+    // Create the convolution op.
+    auto strideAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
+    auto dilationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
@@ -976,32 +1085,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
     rewriter.replaceOp(op, reshape);
     return success();
   }
-
-  return failure();
-}
-
-namespace {
-
-template <typename SrcOp>
-class PointwiseConverter : public OpRewritePattern<SrcOp> {
-public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const final {
-    return elementwiseMatchAndRewriteHelper(op, rewriter);
-  }
-};
-
-template <typename T>
-class ConvConverter : public OpConversionPattern<T> {
-public:
-  using OpConversionPattern<T>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(T op, ArrayRef<Value> args,
-                  ConversionPatternRewriter &rewriter) const final {
-    return convolutionMatchAndRewriterHelper(op, rewriter);
-  }
 };
 
 class TransposeConvConverter
@@ -2528,8 +2611,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       ReduceConverter<tosa::ReduceProdOp>,
       ArgMaxConverter,
       ConcatConverter,
-      ConvConverter<tosa::Conv2DOp>,
-      ConvConverter<tosa::DepthwiseConv2DOp>,
+      ConvConverter,
+      DepthwiseConvConverter,
       TransposeConvConverter,
       GatherConverter,
       PadConverter,
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index fc92c196a059b46c3a8c74c6599c4a3c85d183a6..b9faeeb831dfa903ce7566ce3ad97b0dbc7afd47 100644
@@ -144,49 +144,39 @@ def dot(
   implements(ContractionOpInterface)
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
 
-
 @linalg_structured_op
-def conv_2d_input_nhwc_filter_ohwi_poly(
-    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
-    K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
+def conv_2d_nchw(
+    I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
+    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
-  """Performs 2-D convolution.
+  """Performs 2-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
-  O[D.n, D.oh, D.ow, D.oc] += cast(
-      U, I[D.n,
-           D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW,
-           D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic])
+  domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+  O[D.n, D.f, D.oh, D.ow] += cast(
+      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+           ]) * cast(U, K[D.f, D.c, D.kh, D.kw])
 
 @linalg_structured_op
-def conv_2d_input_nhwc_filter_ohwi_poly_q(
-    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
-    K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
-    IZp=ScalarDef(I32),
-    KZp=ScalarDef(I32),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
+def conv_2d_nhwc_hwcf(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
-  """Performs a 2-D quantized convolution.
+  """Performs 2-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Includes zero point
-  adjustment for quantization.
+  them to the same data type as the accumulator/output.
   """
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
-  O[D.n, D.oh, D.ow, D.oc] += ((cast(
-      U, I[D.n,
-           D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW,
-           D.ic]) - cast(U, IZp)) *
-           (cast(U, K[D.oc, D.kh, D.kw, D.ic]) - cast(U, KZp)))
-
+  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.f] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
+           ]) * cast(U, K[D.kh, D.kw, D.c, D.f])
 
 @linalg_structured_op
 def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
@@ -206,24 +196,27 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
 
 @linalg_structured_op
-def conv_2d_nchw(
-    I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
-    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
-    O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
+def conv_2d_nhwc_hwcf_q(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
-  """Performs 2-D convolution.
+  """Performs 2-D convolution with zero point offsets.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
+  them to the same data type as the accumulator/output. This includes the zero
+  point offsets common to quantized operations.
   """
-  domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.f, D.oh, D.ow] += cast(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           ]) * cast(U, K[D.f, D.c, D.kh, D.kw])
+  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.f] += (cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
+           ]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp))
 
-
-def depthwise_conv2D_nchw(  #TODO: Fix name
+@linalg_structured_op
+def depthwise_conv2D_nchw(
     I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
     K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
     O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
@@ -239,8 +232,8 @@ def depthwise_conv2D_nchw(  #TODO: Fix name
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm])
 
-
-def depthwise_conv2D_nchw_q(  #TODO: Fix name
+@linalg_structured_op
+def depthwise_conv2D_nchw_q(
     I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
     K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
     IZp=ScalarDef(I32),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 309846d66c94eec7576f8a7ff2656a3622cb05cb..3c89de39518785f8ce9bc7f240e1b997f9351006 100644
@@ -1176,14 +1176,19 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
 
 // -----
 
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 
-// CHECK-LABEL: @conv2d_f32
+// CHECK-LABEL @conv2d_f32
 func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28]
-  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
-  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[W_IN:.+]] = linalg.init_tensor [3, 3, 27, 28]
+  // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
+  // CHECK:   linalg.yield %arg3 : f32
+  // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
+  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
+  // CHECK:   linalg.yield %arg3 : f32
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %1 : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[B]] : tensor<1x45x40x28xf32>)
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
   return
 }
@@ -1192,26 +1197,17 @@ func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>
 
 // CHECK-LABEL: @conv2d_padded_f32
 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: linalg.pad_tensor %arg0
-  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly
+  // CHECK: linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK: linalg.conv_2d_nhwc_hwcf
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
   return
 }
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
 // CHECK-LABEL: @conv2d_quant
 func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>, %arg2 : tensor<1024xi32>) -> () {
-  // CHECK:   %[[INIT:.+]] = linalg.init_tensor [1, 10, 10, 1024]
-  // CHECK:   %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1024xi32>) outs(%[[INIT]] : tensor<1x10x10x1024xi32>)
-  // CHECK:   ^bb0(%arg3: i32, %arg4: i32): 
-  // CHECK:     linalg.yield %arg3 : i32
-  // CHECK:   %[[C128:.+]] = constant -128 
-  // CHECK:   %[[C42:.+]] = constant 42 
-  // CHECK:   linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%1 : tensor<1x10x10x1024xi32>)
+  // CHECK: linalg.conv_2d_nhwc_hwcf_q
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x10x10x1024xi32>
   return
 }
@@ -1229,7 +1225,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
   // CHECK:   linalg.yield %arg3 : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
   // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
-  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2D_nchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
+  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
   // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
   %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>)  -> (tensor<1x5x5x33xf32>)
   return
@@ -1260,8 +1256,8 @@ func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x12
 
 // CHECK-LABEL: @transpose_conv
 func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
-  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
-  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x16x16x2xf32>, tensor<4x3x3x2xf32>)
+  // CHECK: linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
+  // CHECK: linalg.conv_2d_nhwc_hwcf
   %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 14, 14, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x14x14x4xf32>
   return
 }
@@ -1271,7 +1267,7 @@ func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>,
 // CHECK-LABEL: @transpose_conv_dilated
 func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
   // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
-  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<4x3x3x2xf32>)
+  // CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<3x3x2x4xf32>)
   %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
   return
 }
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 138d6c219dd2c26ff1b98ca5e791f0f334833715..d19b87c487f9012d8486afb28115335d4d8f645e 100644
@@ -1,19 +1,5 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
-// CHECK-LABEL: func @conv_2d_input_nhwc_filter_ohwi_poly_q_tensor
-func @conv_2d_input_nhwc_filter_ohwi_poly_q_tensor(%input: tensor<2x4x5x3xi8>, %filter: tensor<2x2x2x3xi8>) -> tensor<2x3x4x2xi32> {
-  %zero = constant 0 : i32
-  %init = linalg.init_tensor [2, 3, 4, 2] : tensor<2x3x4x2xi32>
-  %fill = linalg.fill(%zero, %init) : i32, tensor<2x3x4x2xi32> -> tensor<2x3x4x2xi32>
-  %c128 = constant -128 : i32
-  %c42 = constant 42 : i32
-  %0 = linalg.conv_2d_input_nhwc_filter_ohwi_poly_q
-     { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
-     ins(%input, %filter, %c128, %c42 : tensor<2x4x5x3xi8>, tensor<2x2x2x3xi8>, i32, i32)
-    outs(%fill : tensor<2x3x4x2xi32>) -> tensor<2x3x4x2xi32>
-  return %0 : tensor<2x3x4x2xi32>
-}
-
 // CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor
 func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
   %zero = constant 0.000000e+00 : f32