#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
}
//===----------------------------------------------------------------------===//
+// TosaToLinalg
+//===----------------------------------------------------------------------===//
+
+def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> {
+ let summary = "Lower TOSA to LinAlg on tensors";
+ let description = [{
+ Pass that converts TOSA operations to the equivalent operations using the
+ tensor operations in LinAlg.
+ }];
+
+ let constructor = "tosa::createTosaToLinalgOnTensors()";
+}
+
+//===----------------------------------------------------------------------===//
// VectorToSCF
//===----------------------------------------------------------------------===//
--- /dev/null
+//===-- TosaToLinalg.h - TOSA optimization pass declarations ----------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA Linalg Dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
+#define MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaToLinalgOnTensors();
+
+/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
+/// the pass, the function will only contain linalg ops or standard ops if the
+/// pipeline succeeds.
+void addTosaToLinalgOnTensorsPasses(OpPassManager &pm);
+
+/// Populates conversion passes from TOSA dialect to Linalg dialect.
+void populateTosaToLinalgOnTensorsConversionPatterns(
+ MLIRContext *context, OwningRewritePatternList *patterns);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
+add_subdirectory(TosaToLinalg)
add_subdirectory(ArmSVEToLLVM)
add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
--- /dev/null
+add_mlir_conversion_library(MLIRTosaToLinalg
+ TosaToLinalg.cpp
+ TosaToLinalgPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLinalg
+ MLIRLinalgUtils
+ MLIRPass
+ MLIRTosa
+ MLIRTosaTransforms
+ MLIRSupport
+ )
--- /dev/null
+//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
+ return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+}
+
+static Value
+createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
+ ArrayRef<Type> resultTypes,
+ PatternRewriter &rewriter) {
+ Location loc = op->getLoc();
+ auto elementTy =
+ op->getResult(0).getType().cast<ShapedType>().getElementType();
+
+ // tosa::AbsOp
+ if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);
+
+ // tosa::AddOp
+ if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);
+
+ if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);
+
+ // tosa::BitwiseAndOp
+ if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
+
+ // tosa::BitwiseOrOp
+ if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::OrOp>(loc, resultTypes, args);
+
+ // tosa::BitwiseXOrOp
+ if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);
+
+ // tosa::LogicalLeftShiftOp
+ if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);
+
+ // tosa::LogicalrightShiftOp
+ if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);
+
+ // tosa::PowOp
+ if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::PowFOp>(loc, resultTypes, args);
+
+ // tosa::SubOp
+ if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
+
+ if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
+ return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
+
+ // tosa::TanhOp
+ if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::TanhOp>(loc, resultTypes, args);
+
+ rewriter.notifyMatchFailure(
+ op, "unhandled op for linalg body calculation for elementwise op");
+ return nullptr;
+}
+
+static LogicalResult
+elementwiseMatchAndRewriteHelper(Operation *operation,
+ PatternRewriter &rewriter) {
+ auto loc = operation->getLoc();
+ auto results = operation->getResults();
+ auto t0 = operation->getOperand(0).getType().template dyn_cast<ShapedType>();
+ if (!t0)
+ return rewriter.notifyMatchFailure(operation,
+ "All results must be a shaped type");
+
+ // For now require no broadcasting. Consider making it support broadcasting
+ // operations.
+ Type uniqueTy = operation->getOperand(0).getType();
+ bool allInputTypesEqual =
+ llvm::all_of(operation->getOperandTypes(),
+ [&](Type operandTy) { return operandTy == uniqueTy; });
+ if (!allInputTypesEqual)
+ return rewriter.notifyMatchFailure(operation,
+ "All operands must have the same type");
+ bool allResultTypesEqual =
+ llvm::all_of(operation->getResultTypes(),
+ [&](Type resultTy) { return resultTy == uniqueTy; });
+ if (!allResultTypesEqual)
+ return rewriter.notifyMatchFailure(
+ operation, "All results must have the same type as the input");
+
+ // Construct the indexing maps needed for linalg.generic ops.
+ SmallVector<Type> bodyArgTypes;
+
+ for (Value in : operation->getOperands())
+ bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+
+ SmallVector<Type> opResultTypes;
+ SmallVector<Value> initTensors;
+ for (auto result : results) {
+ auto resultType = result.getType().template cast<ShapedType>();
+ if (!resultType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ operation,
+ "tosa to linalg conversion expects statically shaped tensors");
+
+ initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
+ loc, ArrayRef<Value>({}), resultType.getShape(),
+ resultType.getElementType()));
+ opResultTypes.push_back(result.getType());
+ }
+
+ auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
+ initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
+
+ // Supports only non-broadcasted operation. Shoudl consider update indexing
+ // map to be multidimensional.
+ unsigned nloops = t0.getRank();
+ AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops);
+ SmallVector<AffineMap, 2> indexingMaps(
+ operation->getNumOperands() + bodyResultTypes.size(), commonIndexingMap);
+
+ bool didEncounterError = false;
+ auto linalgOp = rewriter.create<linalg::GenericOp>(
+ loc, opResultTypes, operation->getOperands(), initTensors, indexingMaps,
+ getNParallelLoopsAttrs(nloops),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+ Value opResult = createLinalgBodyCalculationForElementwiseOp(
+ operation, blockArgs.take_front(operation->getNumOperands()),
+ bodyResultTypes, rewriter);
+ if (opResult) {
+ didEncounterError = true;
+ }
+ nestedBuilder.create<linalg::YieldOp>(loc, opResult);
+ });
+
+ if (!didEncounterError)
+ return failure();
+
+ rewriter.replaceOp(operation, linalgOp->getResults());
+ return success();
+}
+
+namespace {
+
+template <typename SrcOp>
+class PointwiseConverter : public OpRewritePattern<SrcOp> {
+public:
+ using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SrcOp op,
+ PatternRewriter &rewriter) const final {
+ return elementwiseMatchAndRewriteHelper(op, rewriter);
+ }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
+ MLIRContext *context, OwningRewritePatternList *patterns) {
+ patterns->insert<
+ PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
+ PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::AbsOp>,
+ PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
+ PointwiseConverter<tosa::BitwiseOrOp>,
+ PointwiseConverter<tosa::BitwiseXorOp>,
+ PointwiseConverter<tosa::LogicalLeftShiftOp>,
+ PointwiseConverter<tosa::LogicalRightShiftOp>>(context);
+}
--- /dev/null
+//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TosaToLinalgOnTensors
+ : public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
+public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<linalg::LinalgDialect, StandardOpsDialect>();
+ }
+
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ ConversionTarget target(getContext());
+ target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+ target.addIllegalDialect<tosa::TosaDialect>();
+ target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+ FuncOp func = getFunction();
+ mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
+ func.getContext(), &patterns);
+ if (failed(applyFullConversion(func, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgOnTensors() {
+ return std::make_unique<TosaToLinalgOnTensors>();
+}
+
+void mlir::tosa::addTosaToLinalgOnTensorsPasses(OpPassManager &pm) {
+ pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
+ pm.addNestedPass<FuncOp>(createTosaToLinalgOnTensors());
+}
--- /dev/null
+// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK: #map = affine_map<() -> ()>
+
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor<f32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+ // CHECK: ^bb0(%arg1: f32, %arg2: f32):
+ // CHECK: [[ELEMENT:%.+]] = absf %arg1
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<f32>
+
+ %0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
+
+ // CHECK: return [[GENERIC]]
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+ // CHECK: ^bb0(%arg1: f32, %arg2: f32):
+ // CHECK: [[ELEMENT:%.+]] = absf %arg1
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<1xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: return [[GENERIC]]
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2] : tensor<1x2xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x2xf32>) outs([[INIT]] : tensor<1x2xf32>) {
+ // CHECK: ^bb0(%arg1: f32, %arg2: f32):
+ // CHECK: [[ELEMENT:%.+]] = absf %arg1
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<1x2xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+
+ // CHECK: return [[GENERIC]]
+ return %0 : tensor<1x2xf32>
+}
+
+// -----
+
+func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.abs'}}
+ %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_simple_f32
+func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: tanh
+ %0 = "tosa.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: absf
+ %1 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: addf
+ %2 = "tosa.add"(%0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: subf
+ %3 = "tosa.sub"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: pow
+ %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_simple_i32
+func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: addi
+ %0 = "tosa.add"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: subi
+ %1 = "tosa.sub"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: and
+ %2 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: or
+ %3 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: xor
+ %4 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: shift_left
+ %5 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: shift_right_unsigned
+ %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ return
+}
+