From 940d3e08cf05bcd2779f6d3879850b2606274e3f Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Mon, 14 Nov 2022 17:29:45 -0800 Subject: [PATCH] [mlir][tosa] Create a profile validation pass for TOSA dialect Add a separate validation pass to check if TOSA operations match with the specification against given requirement. Perform profile type checking as the initial feature in the pass. This is an optional pass that can be enabled via command line. e.g. $mlir-opt --tosa-validate="profile=bi" for validating against the base inference profile. Description: TOSA defines a variety of operator behavior and requirements in the specification. It would be helpful to have a separate validation pass to keep TOSA operation input match with TOSA specification for given criteria, and also diminish the burden of dialect validation during compilation. TOSA supports three profiles of which two are for inference purposes. The main inference profile supports both integer and floating-point data types, but the base inference profile only supports integers. In this initial PR, validate the operations against a given profile of TOSA, so that validation would fail if a floating point tensor is present when the base inference profile is selected. Afterward, others checking will be added to the pass if needed. e.g. control flow operators and custom operators validation. The pass is expected to be able to run on any point of TOSA dialect conversion/transformation pipeline, and not depend on a particular pass run ahead. So that it is can be used to validate the initial tosa operations just converted from other dialects, the intermediate form, or the final tosa operations output. Change-Id: Ib58349c873c783056e89d2ab3b3312b8d2c61863 Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D137279 --- .../mlir/Dialect/Tosa/Transforms/CMakeLists.txt | 2 + mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h | 2 + .../include/mlir/Dialect/Tosa/Transforms/Passes.td | 25 ++++++++ .../Conversion/TosaToLinalg/TosaToLinalgPass.cpp | 1 + mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt | 1 + .../lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 68 ++++++++++++++++++++++ 6 files changed, 99 insertions(+) create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt index b1363b5..d4e2661 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,5 +1,7 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) +mlir_tablegen(PassesEnums.h.inc -gen-enum-decls) +mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRTosaPassIncGen) add_dependencies(mlir-headers MLIRTosaPassIncGen) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 9de328a..d6ae781 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc" #include "mlir/Pass/Pass.h" namespace mlir { @@ -37,6 +38,7 @@ std::unique_ptr createTosaInferShapesPass(); std::unique_ptr createTosaMakeBroadcastablePass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); std::unique_ptr createTosaOptionalDecompositions(); +std::unique_ptr createTosaValidationPass(); #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 46bd7a4..c1334be 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES +include "mlir/IR/EnumAttr.td" include "mlir/Pass/PassBase.td" def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> { @@ -63,4 +64,28 @@ def TosaOptionalDecompositions let constructor = "tosa::createTosaOptionalDecompositions()"; } +def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile", + [ + I32EnumAttrCase<"BaseInference", 0, "bi">, + I32EnumAttrCase<"MainInference", 1, "mi">, + I32EnumAttrCase<"MainTraining", 2, "mt">, + I32EnumAttrCase<"Undefined", 3> + ]>{ + let cppNamespace = "mlir::tosa"; +} + +def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> { + let summary = "Validates TOSA dialect"; + let description = [{ + This pass validates if input TOSA operations match the specification for given + criteria, e.g. TOSA profile. + }]; + let constructor = "createTosaValidationPass()"; + + let options = [ + Option<"profileName", "profile", "std::string", + /*default=*/"\"undefined\"", + "Validation if ops match for given profile">]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 1cb6e20..5290923c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -84,5 +84,6 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm, // TODO: Remove pass that operates on const tensor and enable optionality pm.addNestedPass(tosa::createTosaLayerwiseConstantFoldPass()); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); + pm.addNestedPass(tosa::createTosaValidationPass()); pm.addNestedPass(tosa::createTosaToLinalg()); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index ae55269..4f5a54d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaLayerwiseConstantFoldPass.cpp TosaMakeBroadcastable.cpp TosaOptionalDecompositions.cpp + TosaValidation.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp new file mode 100644 index 0000000..36a7a3c --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -0,0 +1,68 @@ +//===- TosaValidation.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Validate if TOSA dialect input matchs with the specification for given +// requirements. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSAVALIDATION +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +//===----------------------------------------------------------------------===// +// TOSA Validation Pass. +//===----------------------------------------------------------------------===// + +struct TosaValidation : public tosa::impl::TosaValidationBase { +public: + explicit TosaValidation() {} + +private: + void runOnOperation() override; + + llvm::Optional profile_type; +}; + +void TosaValidation::runOnOperation() { + profile_type = symbolizeEnum(profileName); + + getOperation().walk([&](Operation *op) { + for (Value operand : op->getOperands()) { + if ((profile_type == TosaProfileEnum::BaseInference) && + getElementTypeOrSelf(operand).isa()) { + return signalPassFailure(); + } + } + }); +} +} // namespace + +std::unique_ptr mlir::tosa::createTosaValidationPass() { + return std::make_unique(); +} -- 2.7.4