[mlir][tosa] Move tosa canonicalizers to optional optimization pass
authorAaron DeBattista <aaron.debattista@arm.com>
Fri, 17 Dec 2021 07:24:47 +0000 (23:24 -0800)
committerRob Suderman <rob.suderman@gmail.com>
Fri, 17 Dec 2021 07:33:54 +0000 (23:33 -0800)
TOSA's canonicalizers that change dense operations should be moved to a
seperate optimization pass to avoid canonicalizing to operations not supported
for relevant backends.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D115890

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp [new file with mode: 0644]
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Tosa/operation_optimization.mlir [new file with mode: 0644]

index 982880e..5978134 100644 (file)
@@ -118,8 +118,6 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
-
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -187,8 +185,6 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
-
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
index 278402e..e94daca 100644 (file)
@@ -22,6 +22,7 @@ namespace tosa {
 std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
+std::unique_ptr<Pass> createTosaOptimizationPass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 
 #define GEN_PASS_REGISTRATION
index 7d6af62..4a75482 100644 (file)
@@ -1,4 +1,4 @@
-//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===//
+//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file declares the optimization passes for the TOSA Dialect in MLIR.
+// This file declares the passes for the TOSA Dialect in MLIR.
 //
 //===----------------------------------------------------------------------===//
 
@@ -58,4 +58,13 @@ def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
   let constructor = "createTosaMakeBroadcastablePass()";
 }
 
+def TosaOptimization : FunctionPass<"tosa-optimization"> {
+  let summary = "TOSA operation optimizations";
+  let description = [{
+    "Pass to perform optimizations on TOSA operations"
+  }];
+
+  let constructor = "createTosaOptimizationPass()";
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
index 9809e57..f61ce68 100644 (file)
@@ -423,197 +423,6 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<MaterializePadValue>(context);
 }
 
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
-
-    for (Attribute pad : op.pad().getValue()) {
-      if (!pad.cast<IntegerAttr>().getValue().isZero()) {
-        return failure();
-      }
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType =
-        RankedTensorType::get(revisedInputShape, inputType.getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType =
-        RankedTensorType::get(revisedWeightShape, weightType.getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType =
-        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
-
-    Value fullyConnectedValue;
-    if (op.quantization_info()) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.bias(), op.quantization_info().getValue())
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.bias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
-                                              inputShape[2], weightShape[0]};
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getI64ArrayAttr(outputShape));
-    return success();
-  }
-};
-
-void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                           MLIRContext *context) {
-  results.insert<Conv2DIsFullyConnected>(context);
-}
-
-struct DepthwiseConv2DMulOptimization
-    : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.output().getType().cast<ShapedType>();
-    Type inputEType = inputType.getElementType();
-
-    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
-          resultType.hasStaticShape())) {
-      return failure();
-    }
-
-    // Quantization information needs to still be performed.
-    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[0] != 1 || weightShape[1] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform an elementwise mul over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
-                                           inputShape[2], inputShape[3],
-                                           weightShape[3]};
-    auto mulShapeType = RankedTensorType::get(
-        mulShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    Value mulValue =
-        rewriter
-            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
-                                 reshapedWeight, /*shift=*/0)
-            .getResult();
-
-    // Reshape output to [N, H, W, C * M].
-    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
-    auto outputShapeType = RankedTensorType::get(
-        outputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto outputValue =
-        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
-                                         rewriter.getI64ArrayAttr(outputShape));
-
-    // Add in the bias.
-    rewriter
-        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
-                                         op.bias())
-        .getResult();
-    return success();
-  }
-};
-
-void DepthwiseConv2DOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<DepthwiseConv2DMulOptimization>(context);
-}
-
 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -747,7 +556,8 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
-template <typename T> static LogicalResult verifyConvOp(T op) {
+template <typename T>
+static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
   auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
index b5e90bb..016575f 100644 (file)
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
   TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
+  TosaOptimization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
new file mode 100644 (file)
index 0000000..61b1618
--- /dev/null
@@ -0,0 +1,243 @@
+//===- TosaOptimization.cpp ------------------------------------------===//\r
+//\r
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\r
+// See https://llvm.org/LICENSE.txt for license information.\r
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\r
+//\r
+//===----------------------------------------------------------------------===//\r
+//\r
+// Pass to perform optimizations on TOSA operations\r
+//\r
+//===----------------------------------------------------------------------===//\r
+\r
+#include "mlir/Analysis/DataFlowAnalysis.h"\r
+#include "mlir/Dialect/StandardOps/IR/Ops.h"\r
+#include "mlir/Dialect/Tensor/IR/Tensor.h"\r
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"\r
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"\r
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"\r
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"\r
+#include "mlir/IR/BlockAndValueMapping.h"\r
+#include "mlir/IR/Builders.h"\r
+#include "mlir/IR/BuiltinOps.h"\r
+#include "mlir/IR/Matchers.h"\r
+#include "mlir/Pass/Pass.h"\r
+#include "mlir/Transforms/DialectConversion.h"\r
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"\r
+#include "llvm/Support/FormatVariadic.h"\r
+\r
+using namespace mlir;\r
+using namespace mlir::tosa;\r
+\r
+#define PASS_NAME "tosa-optimization"\r
+#define DEBUG_TYPE PASS_NAME\r
+\r
+namespace {\r
+\r
+struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {\r
+  explicit Conv2DIsFullyConnected(MLIRContext *context)\r
+      : OpRewritePattern(context) {}\r
+\r
+  LogicalResult matchAndRewrite(tosa::Conv2DOp op,\r
+                                PatternRewriter &rewriter) const {\r
+    Value input = op.input();\r
+    Value weight = op.weight();\r
+    ShapedType inputType = input.getType().cast<ShapedType>();\r
+    ShapedType weightType = weight.getType().cast<ShapedType>();\r
+    ShapedType resultType = op.getType().cast<ShapedType>();\r
+\r
+    if (!inputType.hasStaticShape() || !weightType.hasRank()) {\r
+      return failure();\r
+    }\r
+\r
+    // Stride must be 1 for this optimization.\r
+    for (Attribute stride : op.stride().getValue()) {\r
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {\r
+        return failure();\r
+      }\r
+    }\r
+\r
+    // Only works for a 1x1 kernel.\r
+    ArrayRef<int64_t> weightShape = weightType.getShape();\r
+    if (weightShape[1] != 1 || weightShape[2] != 1) {\r
+      return failure();\r
+    }\r
+\r
+    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].\r
+    ArrayRef<int64_t> inputShape = inputType.getShape();\r
+    llvm::SmallVector<int64_t, 2> revisedInputShape{\r
+        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};\r
+    auto revisedInputShapeType = RankedTensorType::get(\r
+        revisedInputShape,\r
+        input.getType().dyn_cast<RankedTensorType>().getElementType());\r
+    auto reshapedInput = rewriter\r
+                             .create<tosa::ReshapeOp>(\r
+                                 op.getLoc(), revisedInputShapeType, input,\r
+                                 rewriter.getI64ArrayAttr(revisedInputShape))\r
+                             .getResult();\r
+\r
+    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].\r
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],\r
+                                                     weightShape[3]};\r
+    auto revisedWeightShapeType = RankedTensorType::get(\r
+        revisedWeightShape,\r
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());\r
+    auto reshapedWeight = rewriter\r
+                              .create<tosa::ReshapeOp>(\r
+                                  op.getLoc(), revisedWeightShapeType, weight,\r
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))\r
+                              .getResult();\r
+\r
+    // Perform a fully connected network over the reshaped input and weight.\r
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{\r
+        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};\r
+    auto fullyConnectedShapeType = RankedTensorType::get(\r
+        fullyConnectedShape,\r
+        resultType.dyn_cast<ShapedType>().getElementType());\r
+\r
+    Value fullyConnectedValue;\r
+    if (op.quantization_info()) {\r
+      fullyConnectedValue =\r
+          rewriter\r
+              .create<tosa::FullyConnectedOp>(\r
+                  op.getLoc(), fullyConnectedShapeType, reshapedInput,\r
+                  reshapedWeight, op.bias(), op.quantization_info().getValue())\r
+              .getResult();\r
+    } else {\r
+      fullyConnectedValue = rewriter\r
+                                .create<tosa::FullyConnectedOp>(\r
+                                    op.getLoc(), fullyConnectedShapeType,\r
+                                    reshapedInput, reshapedWeight, op.bias())\r
+                                .getResult();\r
+    }\r
+\r
+    // Reshape output to [N, IH, IW, OC].\r
+    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],\r
+                                              inputShape[2], weightShape[0]};\r
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(\r
+        op, resultType, fullyConnectedValue,\r
+        rewriter.getI64ArrayAttr(outputShape));\r
+    return success();\r
+  }\r
+};\r
+\r
+struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {\r
+  explicit DepthwiseConv2DIsMul(MLIRContext *context)\r
+      : OpRewritePattern(context) {}\r
+\r
+  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,\r
+                                PatternRewriter &rewriter) const {\r
+    Value input = op.input();\r
+    Value weight = op.weight();\r
+    ShapedType inputType = input.getType().cast<ShapedType>();\r
+    ShapedType weightType = weight.getType().cast<ShapedType>();\r
+    ShapedType resultType = op.output().getType().cast<ShapedType>();\r
+    Type inputEType = inputType.getElementType();\r
+\r
+    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&\r
+          resultType.hasStaticShape())) {\r
+      return failure();\r
+    }\r
+\r
+    // Quantization information needs to still be performed.\r
+    if (op.quantization_info() || !inputEType.isa<FloatType>()) {\r
+      return failure();\r
+    }\r
+\r
+    // Stride must be 1 for this optimization.\r
+    for (Attribute stride : op.stride().getValue()) {\r
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {\r
+        return failure();\r
+      }\r
+    }\r
+\r
+    // Only works for a 1x1 kernel.\r
+    ArrayRef<int64_t> weightShape = weightType.getShape();\r
+    if (weightShape[0] != 1 || weightShape[1] != 1) {\r
+      return failure();\r
+    }\r
+\r
+    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].\r
+    ArrayRef<int64_t> inputShape = inputType.getShape();\r
+    llvm::SmallVector<int64_t, 2> revisedInputShape{\r
+        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};\r
+    auto revisedInputShapeType = RankedTensorType::get(\r
+        revisedInputShape,\r
+        input.getType().dyn_cast<RankedTensorType>().getElementType());\r
+    auto reshapedInput = rewriter\r
+                             .create<tosa::ReshapeOp>(\r
+                                 op.getLoc(), revisedInputShapeType, input,\r
+                                 rewriter.getI64ArrayAttr(revisedInputShape))\r
+                             .getResult();\r
+\r
+    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].\r
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],\r
+                                                     weightShape[3]};\r
+    auto revisedWeightShapeType = RankedTensorType::get(\r
+        revisedWeightShape,\r
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());\r
+    auto reshapedWeight = rewriter\r
+                              .create<tosa::ReshapeOp>(\r
+                                  op.getLoc(), revisedWeightShapeType, weight,\r
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))\r
+                              .getResult();\r
+\r
+    // Perform an elementwise mul over the reshaped input and weight.\r
+    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],\r
+                                           inputShape[2], inputShape[3],\r
+                                           weightShape[3]};\r
+    auto mulShapeType = RankedTensorType::get(\r
+        mulShape,\r
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());\r
+    Value mulValue =\r
+        rewriter\r
+            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,\r
+                                 reshapedWeight, /*shift=*/0)\r
+            .getResult();\r
+\r
+    // Reshape output to [N, H, W, C * M].\r
+    auto outputShape = op.output().getType().cast<ShapedType>().getShape();\r
+    auto outputShapeType = RankedTensorType::get(\r
+        outputShape,\r
+        input.getType().dyn_cast<RankedTensorType>().getElementType());\r
+    auto outputValue =\r
+        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,\r
+                                         rewriter.getI64ArrayAttr(outputShape));\r
+\r
+    // Add in the bias.\r
+    rewriter\r
+        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,\r
+                                         op.bias())\r
+        .getResult();\r
+    return success();\r
+  }\r
+};\r
+\r
+class TosaOptimization : public PassWrapper<TosaOptimization, FunctionPass> {\r
+public:\r
+  explicit TosaOptimization() {}\r
+  void runOnFunction() override;\r
+\r
+  StringRef getArgument() const final { return PASS_NAME; }\r
+  StringRef getDescription() const final {\r
+    return "Applies TOSA Operation Optimizations";\r
+  }\r
+};\r
+\r
+void TosaOptimization::runOnFunction() {\r
+  OwningRewritePatternList patterns(&getContext());\r
+\r
+  patterns.insert<Conv2DIsFullyConnected>(&getContext());\r
+  patterns.insert<DepthwiseConv2DIsMul>(&getContext());\r
+\r
+  auto func = getFunction();\r
+  if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) {\r
+    signalPassFailure();\r
+  }\r
+}\r
+\r
+} // namespace\r
+\r
+std::unique_ptr<Pass> mlir::tosa::createTosaOptimizationPass() {\r
+  return std::make_unique<TosaOptimization>();\r
+}\r
index fa5304b..1f2f9a1 100644 (file)
@@ -68,52 +68,6 @@ func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 
 // -----
 
-// CHECK-LABEL: @conv2d_as_fully_connected
-func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: -> tensor<400x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xf32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected_quant
-func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xi8>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xi8>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
-  // CHECK-SAME: -> tensor<400x3xi32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xi32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
-  return %0 : tensor<4x10x10x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_padded
-func @conv2d_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x12x12x3xf32> {
-  // CHECK: "tosa.conv2d"
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x12x12x3xf32>
-  return %0 : tensor<4x12x12x3xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @conv2d_stride_2
 func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
   // CHECK: "tosa.conv2d"
@@ -136,35 +90,6 @@ func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
 
 // -----
 
-// CHECK-LABEL: @depthwise_conv2d_as_mul
-func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
-  // CHECK-NOT: "tosa.depthwise_conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
-  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
-  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
-  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: return %[[VAR4]]
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
-  return %0 : tensor<4x10x10x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv2d_as_mul_q
-func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
-  // CHECK: "tosa.depthwise_conv2d"
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
-  return %0 : tensor<4x10x10x6xi32>
-}
-
-// -----
-
 // CHECK-LABEL: @depthwise_conv2d_stride_2
 func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: "tosa.depthwise_conv2d"
diff --git a/mlir/test/Dialect/Tosa/operation_optimization.mlir b/mlir/test/Dialect/Tosa/operation_optimization.mlir
new file mode 100644 (file)
index 0000000..aa65b96
--- /dev/null
@@ -0,0 +1,69 @@
+// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s\r
+\r
+// -----\r
+\r
+// CHECK-LABEL: @conv2d_as_fully_connected\r
+func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {\r
+  // CHECK-NOT: "tosa.conv2d"\r
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}\r
+  // CHECK-SAME: -> tensor<400x2xf32>\r
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}\r
+  // CHECK-SAME: -> tensor<3x2xf32>\r
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)\r
+  // CHECK-SAME: -> tensor<400x3xf32>\r
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}\r
+  // CHECK-SAME: -> tensor<4x10x10x3xf32>\r
+  // CHECK: return %[[VAR3]]\r
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>\r
+  return %0 : tensor<4x10x10x3xf32>\r
+}\r
+\r
+// -----\r
+\r
+// CHECK-LABEL: @conv2d_as_fully_connected_quant\r
+func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {\r
+  // CHECK-NOT: "tosa.conv2d"\r
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}\r
+  // CHECK-SAME: -> tensor<400x2xi8>\r
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}\r
+  // CHECK-SAME: -> tensor<3x2xi8>\r
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)\r
+  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}\r
+  // CHECK-SAME: -> tensor<400x3xi32>\r
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}\r
+  // CHECK-SAME: -> tensor<4x10x10x3xi32>\r
+  // CHECK: return %[[VAR3]]\r
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>\r
+  return %0 : tensor<4x10x10x3xi32>\r
+}\r
+\r
+// -----\r
+\r
+// CHECK-LABEL: @depthwise_conv2d_as_mul\r
+func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {\r
+  // CHECK-NOT: "tosa.depthwise_conv2d"\r
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}\r
+  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>\r
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}\r
+  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>\r
+  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])\r
+  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>\r
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}\r
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>\r
+  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)\r
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>\r
+  // CHECK: return %[[VAR4]]\r
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>\r
+  return %0 : tensor<4x10x10x6xf32>\r
+}\r
+\r
+// -----\r
+\r
+// CHECK-LABEL: @depthwise_conv2d_as_mul_q\r
+func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {\r
+  // CHECK: "tosa.depthwise_conv2d"\r
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>\r
+  return %0 : tensor<4x10x10x6xi32>\r
+}\r
+\r
+// -----\r