From: natashaknk
Date: Wed, 12 Jan 2022 22:10:27 +0000 (-0800)
Subject: [tosa][mlir] Support dynamic batch dimension for ops where the batch dim is explicit
X-Git-Tag: upstream/15.0.7~20721
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=310e9636caeb2f3f02f3cc5bc2f180248061bbe5;p=platform%2Fupstream%2Fllvm.git

[tosa][mlir] Support dynamic batch dimension for ops where the batch dim is explicit

Supports a dynamic batch dimension for rescale, gather, max_pool, avg_pool, conv2D and depthwise_conv2D.
Splits the shared helper functions into a separate header file.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D117031
---

diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
new file mode 100644
index 0000000..bbf1498
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
@@ -0,0 +1,84 @@
+//===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utility functions for TOSA lowering
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
+#define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tosa {
+
+// Creates a SmallVector of StringRefs for N parallel loops.
+SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
+
+// Takes a vector of values and condenses them to a vector with no gaps.
+SmallVector<Value> condenseValues(const SmallVector<Value> &values);
+
+// Takes the parameters for a clamp and turns them into a series of ops.
+template <typename T, typename P>
+mlir::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
+                           arith::ConstantOp max, P pred, OpBuilder &rewriter) {
+  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
+  auto minOrArg =
+      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
+  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
+  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
+}
+
+// Returns the values in an attribute as an array of values.
+template <typename T>
+void getValuesFromIntArrayAttribute(ArrayAttr attr,
+                                    SmallVector<T> &arrayValues) {
+  for (Attribute val : attr.getValue()) {
+    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+}
+
+// Checks for a dynamic batch dim in any of the passed parameters of an op.
+// The batch dimension must be #0 and the rest of the dimensions must be static.
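+// Returns a vector holding the dynamic batch dimension (usable as the dynamic
+// sizes of linalg.init_tensor), an empty vector if all params are static, or
+// llvm::None if any non-batch dimension is dynamic.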
+template <typename Op>
+Optional<SmallVector<Value>> checkHasDynamicBatchDims(PatternRewriter &rewriter,
+                                                      Op op,
+                                                      ArrayRef<Value> params) {
+  SmallVector<ShapedType> dynTypes;
+  SmallVector<Value> dynamicDims;
+  for (const Value &param : params) {
+    auto paramTy = param.getType().cast<ShapedType>();
+    if (!paramTy.hasStaticShape())
+      dynTypes.push_back(paramTy);
+  }
+
+  if (dynTypes.empty())
+    return dynamicDims;
+
+  for (const ShapedType &dynTy : dynTypes) {
+    if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
+      (void)rewriter.notifyMatchFailure(
+          op, "input can only be dynamic for batch size");
+      return llvm::None;
+    }
+  }
+
+  dynamicDims.push_back(
+      rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0));
+  return dynamicDims;
+}
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e9a5c37..1fab060 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -27,10 +28,7 @@
 #include <numeric>
 
 using namespace mlir;
-
-static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
-}
+using namespace mlir::tosa;
 
 template <typename T>
 static arith::ConstantOp
@@ -42,33 +40,6 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
 
-template <typename T>
-static void getValuesFromIntArrayAttribute(ArrayAttr attr,
-                                           SmallVector<T> &arrayValues) {
-  for (Attribute val : attr.getValue()) {
-    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
-  }
-}
-
-template <typename T, typename P>
-static mlir::SelectOp clampHelper(Location loc, Value arg,
-                                  arith::ConstantOp min, arith::ConstantOp max,
-                                  P pred, OpBuilder &rewriter) {
-  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
-  auto minOrArg =
-      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
-  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
-  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
-}
-
-static SmallVector<Value> filterDynamicDims(const SmallVector<Value> &dynDims) {
-  SmallVector<Value> filteredDims;
-  for (auto dim : dynDims)
-    if (dim)
-      filteredDims.push_back(dim);
-  return filteredDims;
-}
-
 static Value
 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                             ArrayRef<Type> resultTypes,
@@ -665,7 +636,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
     }
   }
 
-  SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+  SmallVector<Value> filteredDims = condenseValues(dynDims);
 
   for (auto result : results) {
     auto resultTy = result.getType().template cast<ShapedType>();
@@ -1184,7 +1155,7 @@ public:
       inputExprs[value] = rewriter.getAffineDimExpr(index);
     }
 
-    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+    SmallVector<Value> filteredDims = condenseValues(dynDims);
 
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
@@ -1221,9 +1192,11 @@ public:
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
-    if (!outputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "tosa to linalg conversion expects statically shaped tensors");
+    auto dynamicDimsOr =
+ checkHasDynamicBatchDims(rewriter, op, {input, op.output()}); + if (!dynamicDimsOr.hasValue()) + return failure(); + SmallVector dynamicDims = dynamicDimsOr.getValue(); // The shift and multiplier values. SmallVector multiplierValues; @@ -1299,8 +1272,7 @@ public: // Construct the indexing maps needed for linalg.generic ops. Value initTensor = rewriter.create( - loc, ArrayRef({}), outputTy.getShape(), - outputTy.getElementType()); + loc, dynamicDims, outputTy.getShape(), outputTy.getElementType()); auto linalgOp = rewriter.create( loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps, @@ -1412,16 +1384,17 @@ public: auto imageH = inputTy.getShape()[1]; auto imageW = inputTy.getShape()[2]; - if (!resultTy.hasStaticShape()) + auto dynamicDimsOr = + checkHasDynamicBatchDims(rewriter, op, {input, op.output()}); + if (!dynamicDimsOr.hasValue()) return failure(); + SmallVector dynamicDims = dynamicDimsOr.getValue(); + if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR") return failure(); - auto initTensor = - rewriter - .create(loc, ArrayRef{}, - resultTy.getShape(), resultElementTy) - .result(); + auto initTensor = rewriter.create( + loc, dynamicDims, resultTy.getShape(), resultElementTy); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; @@ -2098,13 +2071,13 @@ public: auto input = adaptor.getOperands()[0]; auto indices = adaptor.getOperands()[1]; - auto inputTy = input.getType().cast(); - auto indicesTy = indices.getType().cast(); auto resultTy = op.getType().cast(); - if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "require input type to have static shape"); + auto dynamicDimsOr = + checkHasDynamicBatchDims(rewriter, op, {input, indices, op.output()}); + if (!dynamicDimsOr.hasValue()) + return failure(); + SmallVector dynamicDims = dynamicDimsOr.getValue(); auto resultElementTy = resultTy.getElementType(); @@ -2112,8 +2085,8 @@ public: auto initTensor = rewriter - .create(loc, ArrayRef{}, - resultTy.getShape(), resultElementTy) + .create(loc, dynamicDims, resultTy.getShape(), + resultElementTy) .result(); SmallVector affineMaps = { diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index a9c525f..54012c9 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -27,29 +28,7 @@ #include using namespace mlir; - -static SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { - return SmallVector(nParallelLoops, getParallelIteratorTypeName()); -} - -template -static void getValuesFromIntArrayAttribute(ArrayAttr attr, - SmallVector &arrayValues) { - for (Attribute val : attr.getValue()) { - arrayValues.push_back(val.cast().getValue().getSExtValue()); - } -} - -template -static mlir::SelectOp clampHelper(Location loc, Value arg, - arith::ConstantOp min, arith::ConstantOp max, - P pred, OpBuilder &rewriter) { - auto smallerThanMin = rewriter.create(loc, pred, arg, min); - auto minOrArg = - rewriter.create(loc, smallerThanMin, min, arg); - auto largerThanMax = rewriter.create(loc, pred, max, arg); - return rewriter.create(loc, 
largerThanMax, max, minOrArg); -} +using namespace mlir::tosa; static mlir::Value applyPad(Location loc, Value input, ArrayRef pad, Attribute padAttr, OpBuilder &rewriter) { @@ -82,14 +61,6 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef pad, .result(); } -static SmallVector filterDynamicDims(const SmallVector &dynDims) { - SmallVector filteredDims; - for (auto dim : dynDims) - if (dim) - filteredDims.push_back(dim); - return filteredDims; -} - namespace { class ConvConverter : public OpConversionPattern { @@ -116,10 +87,15 @@ public: auto dilationTosaAttr = op->getAttr("dilation").cast(); bool isQuantized = op->hasAttr("quantization_info"); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) - return rewriter.notifyMatchFailure(op, - "tosa.conv ops require static shapes"); + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "tosa.conv ops require static shapes for weight and bias"); + + auto dynamicDimsOr = + checkHasDynamicBatchDims(rewriter, op, {input, op.output()}); + if (!dynamicDimsOr.hasValue()) + return failure(); + SmallVector dynamicDims = dynamicDimsOr.getValue(); if (inputETy.isUnsignedInteger()) return rewriter.notifyMatchFailure( @@ -172,7 +148,7 @@ public: Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( - loc, resultTy.getShape(), resultETy); + loc, dynamicDims, resultTy.getShape(), resultETy); Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -197,7 +173,7 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); Value biasInitTensor = rewriter.create( - loc, resultTy.getShape(), resultETy); + loc, dynamicDims, resultTy.getShape(), resultETy); if (isQuantized) { auto quantizationInfo = @@ -292,10 +268,15 @@ public: quantizationInfo.weight_zp().getValue().getSExtValue()); } - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) - return rewriter.notifyMatchFailure(op, - "tosa.conv ops require static shapes"); + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "tosa.depthwise_conv ops require static shapes"); + + auto dynamicDimsOr = + checkHasDynamicBatchDims(rewriter, op, {input, op.output()}); + if (!dynamicDimsOr.hasValue()) + return failure(); + SmallVector dynamicDims = dynamicDimsOr.getValue(); auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); @@ -354,13 +335,13 @@ public: Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( - loc, linalgConvTy.getShape(), resultETy); + loc, dynamicDims, linalgConvTy.getShape(), resultETy); Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); Value biasInitTensor = rewriter.create( - loc, resultTy.getShape(), resultETy); + loc, dynamicDims, resultTy.getShape(), resultETy); if (!isQuantized) { Value conv = rewriter .create( @@ -442,7 +423,7 @@ public: dynDims[2] = rewriter.create(loc, op->getOperand(1), 2); } - SmallVector filteredDims = filterDynamicDims(dynDims); + SmallVector filteredDims = condenseValues(dynDims); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); Value zero = rewriter.create(loc, zeroAttr); @@ -503,7 +484,7 @@ public: dynDims[1] = rewriter.create(loc, 
weight, 0); } - SmallVector filteredDims = filterDynamicDims(dynDims); + SmallVector filteredDims = condenseValues(dynDims); // Creating maps for the output of MatMul and the bias SmallVector indexingMaps; @@ -611,8 +592,11 @@ public: ShapedType resultTy = op.getType().template cast(); Type resultETy = inputTy.getElementType(); - if (!inputTy.hasStaticShape()) + auto dynamicDimsOr = + checkHasDynamicBatchDims(rewriter, op, {input, op.output()}); + if (!dynamicDimsOr.hasValue()) return failure(); + SmallVector dynamicDims = dynamicDimsOr.getValue(); // Determine what the initial value needs to be for the max pool op. Attribute initialAttr; @@ -649,7 +633,7 @@ public: // Create the linalg op that performs pooling. Value initTensor = rewriter.create( - loc, resultTy.getShape(), resultTy.getElementType()); + loc, dynamicDims, resultTy.getShape(), resultTy.getElementType()); Value filledInitTensor = rewriter.create(loc, initialValue, initTensor).result(); @@ -682,8 +666,11 @@ public: inElementTy.isa() ? rewriter.getI32Type() : inElementTy; ShapedType accTy = resultTy.clone(accETy); - if (!inputTy.hasStaticShape()) + auto dynamicDimsOr = + checkHasDynamicBatchDims(rewriter, op, {input, op.output()}); + if (!dynamicDimsOr.hasValue()) return failure(); + SmallVector dynamicDims = dynamicDimsOr.getValue(); // Apply padding as necessary. llvm::SmallVector pad; @@ -704,8 +691,8 @@ public: Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value poolInitTensor = - rewriter.create(loc, accTy.getShape(), accETy); + Value poolInitTensor = rewriter.create( + loc, dynamicDims, accTy.getShape(), accETy); Value filledInitTensor = rewriter.create(loc, initialValue, poolInitTensor) @@ -728,7 +715,7 @@ public: auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); Value genericInitTensor = rewriter.create( - loc, resultTy.getShape(), resultETy); + loc, dynamicDims, resultTy.getShape(), resultETy); auto genericOp = rewriter.create( loc, ArrayRef({resultTy}), ValueRange{poolingOp}, @@ -770,7 +757,7 @@ public: auto kH2 = padFn(kH1, y1, pad[3]); auto kHCmp = rewriter.create( loc, arith::CmpIPredicate::slt, kH2, one); - auto kH3 = rewriter.create(loc, kHCmp, one, kH2); + auto kH3 = rewriter.create(loc, kHCmp, one, kH2); // compute the horizontal component of coverage. auto kW0 = rewriter.create(loc, kernel[1]); @@ -778,7 +765,7 @@ public: auto kW2 = padFn(kW1, x1, pad[5]); auto kWCmp = rewriter.create( loc, arith::CmpIPredicate::slt, kW2, one); - auto kW3 = rewriter.create(loc, kWCmp, one, kW2); + auto kW3 = rewriter.create(loc, kWCmp, one, kW2); // Compute the total number of elements and normalize. Value count = rewriter.create(loc, kH3, kW3); diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index cae6d14..9a9c80f 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosa + Utils/ConversionUtils.cpp Utils/QuantUtils.cpp IR/TosaOps.cpp diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp new file mode 100644 index 0000000..e994adb --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -0,0 +1,30 @@ +//===- ConversionUtils.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utility functions for TOSA lowering +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h" + +using namespace mlir; +using namespace mlir::tosa; + +SmallVector +mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { + return SmallVector(nParallelLoops, getParallelIteratorTypeName()); +} + +SmallVector +mlir::tosa::condenseValues(const SmallVector &values) { + SmallVector condensedValues; + for (auto value : values) + if (value) + condensedValues.push_back(value); + return condensedValues; +} diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index f581488..9db6c58 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -164,6 +164,19 @@ func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () { return } +// CHECK-LABEL: @max_pool_dyn +func @max_pool_dyn(%arg0: tensor) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38 + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62] + // CHECK: %[[FILL:.+]] = linalg.fill(%[[CONST]], %[[INIT]]) + // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3] + // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor, tensor<3x3xf32>) outs(%[[FILL]] : tensor) + %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor) -> (tensor) + return +} + // CHECK-LABEL: @max_pool_i8 func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () { // CHECK: arith.constant -128 @@ -250,6 +263,24 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) { // ----- +// CHECK-LABEL: @avg_pool_dyn +func @avg_pool_dyn(%arg0: tensor) -> (tensor) { + // The calculations remain the same as above, only testing for dyn behavior + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[PAD:.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] + // CHECK: %[[POOLINIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62] + // CHECK: %[[FILL:.+]] = linalg.fill + // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [4, 4] + // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor, tensor<4x4xf32>) outs(%[[FILL]] : tensor) + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62] + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor) outs(%[[INIT]] : tensor) + %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor) -> (tensor) + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @avg_pool_i8 func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () { @@ -329,6 +360,29 @@ func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32> // ----- +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK-LABEL: 
@conv2d_dyn +func @conv2d_dyn(%input: tensor, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> + // CHECK: %[[W:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) + // CHECK: %[[M_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28] + // CHECK: %[[CST:.+]] = arith.constant 0 + // CHECK: %[[FILL:.+]] = linalg.fill + // CHECK: %[[B_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28] + // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor) + // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor) outs(%[[B_IN]] : tensor) + // CHECK: %[[ADD:.+]] = arith.addf + // CHECK: linalg.yield %[[ADD]] : f32 + %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor) + return +} + +// ----- + // CHECK-LABEL: @conv2d_padded_f32 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 @@ -378,6 +432,30 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: @depthwise_conv_dyn +func @depthwise_conv_dyn(%arg0 : tensor, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 3, 11] + // CHECK: %[[CST0:.+]] = arith.constant 0 + // CHECK: %[[FILL:.+]] = linalg.fill + // CHECK: %[[OUT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 33] + // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor) + // CHECK: %[[COLLAPSED:.+]] = "tosa.reshape"(%[[DEPTH]]) {new_shape = [-1, 5, 5, 33]} + // CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor) outs(%[[OUT]] : tensor) { + // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + // CHECK: %[[ADD:.+]] = arith.addf %arg3, %arg4 : f32 + // CHECK: linalg.yield %[[ADD]] : f32 + // CHECK: } -> tensor + %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor) + return +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + // CHECK-LABEL: @depthwise_conv_strides func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11] diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir 
b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index e68e76c..3706c41 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -897,6 +897,26 @@ func @rescale_i8(%arg0 : tensor<2xi8>) -> () { // ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @rescale_i8_dyn +func @rescale_i8_dyn(%arg0 : tensor) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor) outs(%[[INIT]] : tensor) + %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor) -> (tensor) + + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor) outs(%[[INIT]] : tensor) + %1 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor) -> (tensor) + + return +} +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: @rescale_ui8 @@ -1184,6 +1204,22 @@ func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () { return } +// CHECK-LABEL: @gather_float_dyn +func @gather_float_dyn(%arg0: tensor, %arg1: tensor) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 3, 2] + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor) outs(%[[INIT]] : tensor) + // CHECK: ^bb0(%[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32) + // CHECK: %[[IDX0:.+]] = linalg.index 0 + // CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG0]] + // CHECK: %[[IDX2:.+]] = linalg.index 2 + // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor + // CHECK: linalg.yield %[[EXTRACT]] + %0 = "tosa.gather"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + return +} + // CHECK-LABEL: @gather_int func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () { // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2] @@ -1548,3 +1584,15 @@ func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () { %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xi8>) -> (tensor<1x4x4x1xi32>) return } + +// ----- + +// CHECK-LABEL: @resize_dyn +func @resize_dyn(%input: tensor) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 4, 1] + // CHECK: %[[GENERIC:.+]] = linalg.generic + %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. 
: f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor) -> (tensor) + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 4399f2f..0b144d9 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7071,13 +7071,14 @@ cc_library( includes = ["include"], deps = [ ":Analysis", + ":ArithmeticDialect", ":Dialect", + ":DialectUtils", ":IR", ":InferTypeOpInterface", ":LoopLikeInterface", ":Pass", ":QuantOps", - ":SideEffectInterfaces", ":StandardOps", ":TensorDialect", ":TosaDialectIncGen",