From d84d418e2adc421c98e484ab3b09e2f4f3e5c1ef Mon Sep 17 00:00:00 2001
From: Tina Jung
Date: Wed, 5 Jul 2023 08:37:47 +0100
Subject: [PATCH] [mlir][tosa] Constant folding for reciprocal

Add a constant fold for tosa.reciprocal, which can be applied if the
input is a dense constant tensor. The reciprocal is computed for every
element, and the result is a tensor with the same dimensions as the
input tensor.

As the input tensor might require a lot of memory and the folding might
double the memory required, a heuristic decides when to actually apply
the folding. Currently, the operation is replaced only if the input
constant is a splat (i.e. requires little memory) or has a single user
(similar to the already existing fold for constant transposes). This
keeps the additional memory required low.

Differential Revision: https://reviews.llvm.org/D150578
---
 mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h |   2 +
 mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt    |   2 +-
 .../Tosa/Transforms/TosaFoldConstantTranspose.cpp  | 138 ----------
 mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp   | 302 +++++++++++++++++++++
 .../Transforms/TosaLayerwiseConstantFoldPass.cpp   |   1 +
 .../Dialect/Tosa/constant-reciprocal-fold.mlir     | 137 ++++++++++
 6 files changed, 443 insertions(+), 139 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
 create mode 100644 mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index d6ae781..c81f59b 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -30,6 +30,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                     RewritePatternSet &patterns);
+void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
+                                                RewritePatternSet &patterns);
 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
                                                RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 4f5a54d..0e6510b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
-  TosaFoldConstantTranspose.cpp
+  TosaFolders.cpp
   TosaInferShapes.cpp
   TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
deleted file mode 100644
index 302e279..0000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ /dev/null
@@ -1,138 +0,0 @@
-//===- TosaFoldConstantTranspose.cpp --------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Fold TOSA Transpose operation on constant data
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-template <typename BaseType>
-DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
-                                ShapedType outputType,
-                                llvm::ArrayRef<int64_t> permValues) {
-  if (inputType.getNumElements() == 0)
-    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
-
-  auto attrValues = attr.getValues<BaseType>();
-  auto inputShape = inputType.getShape();
-
-  // The inverted permutation map and strides of the output are used to compute
-  // the contribution of a given dimension to the destination linear index in
-  // an order-independent way.
-  auto outputStrides = computeStrides(outputType.getShape());
-  auto invertedPermValues = invertPermutationVector(permValues);
-
-  auto initialValue = *std::begin(attrValues);
-  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
-
-  for (const auto &it : llvm::enumerate(attrValues)) {
-    auto srcLinearIndex = it.index();
-
-    uint64_t dstLinearIndex = 0;
-    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
-      // Compute the index into the current dimension of the source vector.
-      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
-      srcLinearIndex /= inputShape[dim];
-
-      // Add the contribution of the current dimension to the output using the
-      // permutation map.
-      dstLinearIndex +=
-          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
-    }
-
-    outputValues[dstLinearIndex] = it.value();
-  }
-
-  return DenseElementsAttr::get(outputType,
-                                llvm::ArrayRef<BaseType>(outputValues));
-}
-
-// A type specialized transposition of an ElementsAttr.
-// This implementation tries to operate on the underlying data in its raw
-// representation when possible to avoid allocating a large number of Attribute
-// objects.
-DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
-                            ShapedType outputType,
-                            llvm::ArrayRef<int64_t> permValues) {
-  auto baseType = inputType.getElementType();
-
-  // Handle possible integer types
-  if (auto intType = dyn_cast<IntegerType>(baseType)) {
-    switch (intType.getWidth()) {
-    case 1:
-      return transposeType<bool>(attr, inputType, outputType, permValues);
-    case 8:
-      return transposeType<int8_t>(attr, inputType, outputType, permValues);
-    case 16:
-      return transposeType<int16_t>(attr, inputType, outputType, permValues);
-    case 32:
-      return transposeType<int32_t>(attr, inputType, outputType, permValues);
-    case 64:
-      return transposeType<int64_t>(attr, inputType, outputType, permValues);
-    default:
-      return transposeType<APInt>(attr, inputType, outputType, permValues);
-    }
-  }
-
-  // Handle possible float types
-  if (baseType.isF32()) {
-    return transposeType<float>(attr, inputType, outputType, permValues);
-  }
-
-  return transposeType<APFloat>(attr, inputType, outputType, permValues);
-}
-
-struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    auto outputType = cast<ShapedType>(op.getType());
-    // TOSA supports quantized types.
-    if (!outputType.getElementType().isIntOrIndexOrFloat())
-      return failure();
-
-    ElementsAttr inputValues;
-    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
-      return failure();
-    // Make sure the input is a constant that has a single user.
-    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
-      return failure();
-
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
-      return failure();
-    auto permValues = llvm::to_vector<6>(llvm::map_range(
-        // TOSA allows both 32- and 64-bit integer tensors here.
-        permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); }));
-
-    auto inputType = cast<ShapedType>(op.getInput1().getType());
-
-    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::tosa::populateTosaFoldConstantTransposePatterns(
-    MLIRContext *ctx, RewritePatternSet &patterns) {
-  patterns.add<TosaFoldConstantTranspose>(ctx);
-}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
new file mode 100644
index 0000000..5869399
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -0,0 +1,302 @@
+//===- TosaFolders.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold TOSA operations
+//
+//===----------------------------------------------------------------------===//
+
+#include <functional>
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Rounding mode to be used on floating point operations that require
+/// rounding.
+static constexpr llvm::RoundingMode tosaRoundingMode =
+    llvm::APFloat::rmNearestTiesToEven;
+
+/// Apply the given transformation \p toApply to every element of the tensor
+/// to be transformed \p toTransform.
+///
+/// Elements of \p toTransform are extracted as \p SrcValueType.
+///
+/// \returns A tensor with the same size as \p toTransform, containing
+/// \p TargetValueType values of type \p TargetType.
+template <class SrcValueType, class TargetValueType, class TargetType>
+DenseElementsAttr applyElementWise(
+    const DenseElementsAttr &toTransform,
+    const std::function<TargetValueType(const SrcValueType &, TargetType)>
+        &toApply,
+    TargetType targetType) {
+  SmallVector<TargetValueType> transformedValues;
+  // We already know the amount of values we will insert, reserve space for
+  // all of them to avoid dynamic resizing
+  transformedValues.reserve(toTransform.getNumElements());
+  for (auto val : toTransform.getValues<SrcValueType>()) {
+    auto transformedVal = toApply(val, targetType);
+    transformedValues.push_back(transformedVal);
+  }
+
+  // Make sure that the output tensor has the expected output type
+  auto inShape = toTransform.getType();
+  auto outTy = inShape.cloneWith({}, targetType);
+
+  return DenseElementsAttr::get(outTy, transformedValues);
+}
+
+template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APFloat(const APFloat &, FloatType)> &toApply,
+    FloatType targetType);
+
+/// Function that checks if the type contained in \p toCheck is float.
+LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
+                               PatternRewriter &rewriter) {
+  if (isa<FloatType>(toCheck.getType().getElementType())) {
+    return success();
+  }
+  return rewriter.notifyMatchFailure(location,
+                                     "Unexpected input tensor type: the "
+                                     "TOSA spec only allows floats");
+}
+
+/// Function that checks if \p toCheck is a dense TOSA constant tensor.
+LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
+                                                TosaOp location,
+                                                PatternRewriter &rewriter) {
+  // Check whether the tensor is constant and dense
+  // TODO We currently ensure the tensor is dense by using the correct type for
+  // the bind_value, however we do not actually need this value. It would be
+  // nicer to only have a check here.
+  DenseElementsAttr tmp;
+  if (!matchPattern(toCheck, m_Constant(&tmp))) {
+    return rewriter.notifyMatchFailure(location,
+                                       "Non-const or non-dense input tensor");
+  }
+
+  // Make sure it actually is a TOSA constant (the match allows for other
+  // constants as well)
+  if (isa<ConstOp>(toCheck.getDefiningOp())) {
+    return success();
+  }
+
+  return rewriter.notifyMatchFailure(location,
+                                     "The reciprocal can only be folded if "
+                                     "it operates on a TOSA constant");
+}
+
+/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
+LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
+                                                 TosaOp location,
+                                                 PatternRewriter &rewriter) {
+  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
+  if (failed(floatCheck)) {
+    return floatCheck;
+  }
+  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
+}
+
+/// Heuristic to decide when to replace a unary operation on a constant with
+/// the folded value.
+/// Folding operations on constants can lead to an increased memory usage
+/// whenever the input cannot be replaced but a new constant is inserted.
+/// Hence, this will currently only suggest folding when the memory impact is
+/// negligible.
+/// Takes the \p unaryOp and the constant input \p values.
+/// \returns Whether folding should be applied.
+bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
+  assert(unaryOp->getNumOperands() == 1);
+  auto inputOp = unaryOp->getOperand(0);
+
+  // If the input is a splat, we don't care for the number of users
+  if (isa<SplatElementsAttr>(values)) {
+    return true;
+  }
+
+  // If this is the only use of the tensor it should be replaced as no
+  // additional memory is required
+  return inputOp.hasOneUse();
+}
+
+template <typename BaseType>
+DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+                                ShapedType outputType,
+                                llvm::ArrayRef<int64_t> permValues) {
+  if (inputType.getNumElements() == 0)
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+
+  auto attrValues = attr.getValues<BaseType>();
+  auto inputShape = inputType.getShape();
+
+  // The inverted permutation map and strides of the output are used to compute
+  // the contribution of a given dimension to the destination linear index in
+  // an order-independent way.
+  auto outputStrides = computeStrides(outputType.getShape());
+  auto invertedPermValues = invertPermutationVector(permValues);
+
+  auto initialValue = *std::begin(attrValues);
+  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+
+  for (const auto &it : llvm::enumerate(attrValues)) {
+    auto srcLinearIndex = it.index();
+
+    uint64_t dstLinearIndex = 0;
+    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
+      // Compute the index into the current dimension of the source vector.
+      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
+      srcLinearIndex /= inputShape[dim];
+
+      // Add the contribution of the current dimension to the output using the
+      // permutation map.
+      dstLinearIndex +=
+          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
+    }
+
+    outputValues[dstLinearIndex] = it.value();
+  }
+
+  return DenseElementsAttr::get(outputType,
+                                llvm::ArrayRef<BaseType>(outputValues));
+}
+
+// A type specialized transposition of an ElementsAttr.
+// This implementation tries to operate on the underlying data in its raw
+// representation when possible to avoid allocating a large number of Attribute
+// objects.
+DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
+                            ShapedType outputType,
+                            llvm::ArrayRef<int64_t> permValues) {
+  auto baseType = inputType.getElementType();
+
+  // Handle possible integer types
+  if (auto intType = dyn_cast<IntegerType>(baseType)) {
+    switch (intType.getWidth()) {
+    case 1:
+      return transposeType<bool>(attr, inputType, outputType, permValues);
+    case 8:
+      return transposeType<int8_t>(attr, inputType, outputType, permValues);
+    case 16:
+      return transposeType<int16_t>(attr, inputType, outputType, permValues);
+    case 32:
+      return transposeType<int32_t>(attr, inputType, outputType, permValues);
+    case 64:
+      return transposeType<int64_t>(attr, inputType, outputType, permValues);
+    default:
+      return transposeType<APInt>(attr, inputType, outputType, permValues);
+    }
+  }
+
+  // Handle possible float types
+  if (baseType.isF32()) {
+    return transposeType<float>(attr, inputType, outputType, permValues);
+  }
+
+  return transposeType<APFloat>(attr, inputType, outputType, permValues);
+}
+
+struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto outputType = cast<ShapedType>(op.getType());
+    // TOSA supports quantized types.
+    if (!outputType.getElementType().isIntOrIndexOrFloat())
+      return failure();
+
+    ElementsAttr inputValues;
+    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
+      return failure();
+    // Make sure the input is a constant that has a single user.
+    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
+      return failure();
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
+      return failure();
+    auto permValues = llvm::to_vector<6>(llvm::map_range(
+        // TOSA allows both 32- and 64-bit integer tensors here.
+        permAttr.getValues<APInt>(),
+        [](const APInt &val) { return val.getSExtValue(); }));
+
+    auto inputType = cast<ShapedType>(op.getInput1().getType());
+
+    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
+    return success();
+  }
+};
+
+struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
+    auto recipAttr = FloatAttr::get(floatTy, 1.0);
+    APFloat recip = recipAttr.getValue();
+    recip.divide(floatVal, tosaRoundingMode);
+
+    return recip;
+  }
+
+  LogicalResult matchAndRewrite(ReciprocalOp recip,
+                                PatternRewriter &rewriter) const override {
+    auto inputTensor = recip.getInput1();
+
+    // Check that we can apply folding
+    auto preCondCheck =
+        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
+    if (failed(preCondCheck)) {
+      return preCondCheck;
+    }
+
+    // Extract the tensor values
+    DenseElementsAttr inputValues;
+    matchPattern(inputTensor, m_Constant(&inputValues));
+
+    // Check whether this should be folded.
+    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
+      return rewriter.notifyMatchFailure(
+          recip, "Currently, reciprocals will only be folded if the input "
+                 "tensor has a single user");
+    }
+
+    // Create a new tensor with the updated values
+    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
+        inputValues, &computeReciprocal,
+        cast<FloatType>(inputValues.getElementType()));
+
+    // Replace the use of the reciprocal with the transformed tensor
+    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaFoldConstantTransposePatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.add<TosaFoldConstantTranspose>(ctx);
+}
+
+void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.add<TosaFoldConstantReciprocal>(ctx);
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index a217f66..2e2d338 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
     RewritePatternSet patterns(ctx);
     auto func = getOperation();
 
+    mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
     populateTosaOpsCanonicalizationPatterns(ctx, patterns);
 
diff --git a/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
new file mode 100644
index 0000000..cc71c43
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// CHECK-LABEL: @reciprocal_fold_single_valued
+func.func @reciprocal_fold_single_valued() -> tensor<f32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<f32>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_fold_splat
+func.func @reciprocal_fold_splat() -> tensor<12x7xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<12x7xf32>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<12x7xf32>} : () -> tensor<12x7xf32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<12x7xf32>) -> tensor<12x7xf32>
+  return %1 : tensor<12x7xf32>
+}
+
+// CHECK-LABEL: @reciprocal_div_zero
+func.func @reciprocal_div_zero() -> tensor<f32> {
+  // 0x7F800000 is the value for +infinity
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_neg_zero
+func.func @reciprocal_div_neg_zero() -> tensor<f32> {
+  // 0xFF800000 is the value for -infinity
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0xFF800000
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<-0.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_nan
+func.func @reciprocal_div_nan() -> tensor<f32> {
+  // 0x7FC00000 is the value for NAN
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7FC00000
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0x7FC00000> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_infinity
+func.func @reciprocal_div_infinity() -> tensor<f32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_neg_infinity
+func.func @reciprocal_div_neg_infinity() -> tensor<f32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<-0.{{0*}}e+00>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_underflow
+func.func @reciprocal_div_underflow() -> tensor<2xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-0.{{0*}}e+00, 0.{{0*}}e+00
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[-6.0e+15, 6.0e+15]> : tensor<2xf16>} : () -> tensor<2xf16>
+  %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16>
+  return %1 : tensor<2xf16>
+}
+
+// CHECK-LABEL: @reciprocal_div_overflow
+func.func @reciprocal_div_overflow() -> tensor<2xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[0.0000001, -0.0000001]> : tensor<2xf16>} : () -> tensor<2xf16>
+  %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16>
+  return %1 : tensor<2xf16>
+}
+
+// CHECK-LABEL: @reciprocal_no_fold
+// The folding optimization works only intra-procedurally, so we won't be able
+// to fold anything here
+func.func @reciprocal_no_fold(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: tosa.reciprocal
+  // CHECK-NEXT: return
+  %0 = "tosa.reciprocal"(%arg0) : (tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_fold
+func.func @reciprocal_fold() -> tensor<4x6xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const
+  // CHECK-SAME{LITERAL}: [[5.68828249, 11.4416485, 1.6880486, 0.680272102, -0.875350117, 0.342313349],
+  // CHECK-SAME{LITERAL}: [-4.81231928, 0.698080301, 0.65432179, -82.6446304, -4.33651352, -0.747551739],
+  // CHECK-SAME{LITERAL}: [-12.4378109, 13.140605, 1.89501607, 0.885582745, 4.08830738, 1.4396776],
+  // CHECK-SAME{LITERAL}: [2.02880907, -1.53280187, 0.552730501, 7.15819644, 0.64495325, -0.973709881]]
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() { value = dense<[
+                        [ 0.1758,  0.0874,  0.5924,  1.4700, -1.1424,  2.9213],
+                        [-0.2078,  1.4325,  1.5283, -0.0121, -0.2306, -1.3377],
+                        [-0.0804,  0.0761,  0.5277,  1.1292,  0.2446,  0.6946],
+                        [ 0.4929, -0.6524,  1.8092,  0.1397,  1.5505, -1.0270]]>
+                        : tensor<4x6xf32>
+                      } : () -> tensor<4x6xf32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<4x6xf32>) -> tensor<4x6xf32>
+  return %1 : tensor<4x6xf32>
+}
+
+// CHECK-LABEL: @reciprocal_of_const_sparse
+// Sparse tensors are currently not supported
+func.func @reciprocal_of_const_sparse() -> tensor<32xbf16> {
+  // CHECK: tosa.const
+  // CHECK: tosa.reciprocal
+  %0 = "tosa.const"() { value = sparse<
+          [[0], [3], [11], [17], [20], [23], [25], [30], [31]],
+          [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]>
+          : tensor<32xbf16> } : () -> tensor<32xbf16>
+  %1 = "tosa.reciprocal"(%0) : (tensor<32xbf16>) -> tensor<32xbf16>
+  return %1 : tensor<32xbf16>
+}
-- 
2.7.4
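P.S. For reviewers: a minimal before/after sketch of the rewrite this patch
enables (not part of the patch; the function name @example is hypothetical,
and the constants mirror the splat tests above):

// Input: a tosa.reciprocal applied to a splat constant with a single user.
func.func @example() -> tensor<2xf32> {
  %0 = "tosa.const"() {value = dense<4.0> : tensor<2xf32>} : () -> tensor<2xf32>
  %1 = "tosa.reciprocal"(%0) : (tensor<2xf32>) -> tensor<2xf32>
  return %1 : tensor<2xf32>
}

// After running `mlir-opt --tosa-layerwise-constant-fold`, the reciprocal is
// folded away and the constant holds the element-wise result 1.0 / 4.0:
func.func @example() -> tensor<2xf32> {
  %0 = "tosa.const"() {value = dense<2.500000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
  return %0 : tensor<2xf32>
}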