#ifndef MLIR_SHAPE_IR_SHAPE_H
#define MLIR_SHAPE_IR_SHAPE_H
-#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
-include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
// Shape op definitions
}
def Shape_YieldOp : Shape_Op<"yield",
- [HasParent<"ReduceOp, FunctionLibraryOp">,
+ [HasParent<"ReduceOp">,
NoSideEffect,
ReturnLike,
Terminator]> {
let hasFolder = 1;
}
-//===----------------------------------------------------------------------===//
-// Shape collection ops.
-//===----------------------------------------------------------------------===//
-
-def Shape_FunctionLibraryOp : Shape_Op<"function_library",
- [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
- SingleBlockImplicitTerminator<"ShapeFunctionLibraryTerminatorOp">]> {
- let summary = "Represents shape functions and corresponding ops";
- let description = [{
- Represents a list of shape functions and the ops whose shape transfer
- functions they represent.
-
- Example:
-
- ```mlir
- shape.function_library {
- func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
- %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
- return %0 : !shape.shape
- }
- } mapping {
- std.atan = @same_result_shape
- }
- ```
- }];
-
- let arguments = (ins SymbolNameAttr:$sym_name,
- OptionalAttr<StrAttr>:$sym_visibility);
- let arguments = (ins DictionaryAttr:$mapping);
- let regions = (region AnyRegion:$body);
-
- let extraClassDeclaration = [{
- /// Returns an associated shape function for an operation if defined.
- FuncOp getShapeFunction(Operation *op);
- }];
-
- let builders = [OpBuilderDAG<(ins "StringRef":$name)>];
- let skipDefaultBuilders = 1;
-
- let printer = [{ ::print(p, *this); }];
- let parser = [{ return ::parse$cppClass(parser, result); }];
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeFunctionLibraryTerminatorOp
-//===----------------------------------------------------------------------===//
-
-def ShapeFunctionLibraryTerminatorOp : Shape_Op<"fn_lib_terminator",
- [Terminator, HasParent<"FunctionLibraryOp">]> {
- let summary = "A pseudo op that marks the end of a shape function library";
- let description = [{
- `shape_fn_lib_terminator` is a special pseudo terminator operation for the
- shape function library. It has no semantic meaning beyond keeping the body
- well-formed.
- }];
- let assemblyFormat = "attr-dict";
-}
-
#endif // SHAPE_OPS
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/InliningUtils.h"
}
//===----------------------------------------------------------------------===//
-// FunctionLibraryOp
-//===----------------------------------------------------------------------===//
-
-void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
- StringRef name) {
- ensureTerminator(*result.addRegion(), builder, result.location);
- result.attributes.push_back(builder.getNamedAttr(
- ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
-}
-
-FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
- auto attr = mapping()
- .get(op->getName().getIdentifier())
- .dyn_cast_or_null<FlatSymbolRefAttr>();
- if (!attr)
- return nullptr;
- return lookupSymbol<FuncOp>(attr);
-}
-
-ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
- OperationState &result) {
- // Parse the op name.
- StringAttr nameAttr;
- if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
- result.attributes))
- return failure();
-
- if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
- return failure();
-
- auto *bodyRegion = result.addRegion();
- if (parser.parseRegion(*bodyRegion))
- return failure();
-
- FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
- result.location);
- if (parser.parseKeyword("mapping"))
- return failure();
-
- DictionaryAttr mappingAttr;
- if (parser.parseAttribute(mappingAttr,
- parser.getBuilder().getType<NoneType>(), "mapping",
- result.attributes))
- return failure();
- return success();
-}
-
-void print(OpAsmPrinter &p, FunctionLibraryOp op) {
- p << op.getOperationName() << ' ';
- p.printSymbolName(op.getName());
- p.printOptionalAttrDictWithKeyword(
- op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
- p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/false);
- p << " mapping ";
- p.printAttributeWithoutType(op.mappingAttr());
-}
-
-//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//
+++ /dev/null
-// RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics
-
-// expected-remark@+1 {{associated shape function: same_result_shape}}
-func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
- attributes {shape.function = @shape_lib::@same_result_shape} {
- // expected-remark@+1 {{no associated way}}
- %0 = tanh %arg : tensor<10x20xf32>
- // expected-remark@+1 {{associated shape function: same_result_shape}}
- %1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
- return %1 : tensor<10x20xf32>
-}
-
-// The shape function library with some local functions.
-shape.function_library @shape_lib {
- // Test shape function that returns the shape of input arg as result shape.
- func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
- %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
- return %0 : !shape.shape
- }
-} mapping {
- test.same_operand_result_type = @same_result_shape
-}
add_subdirectory(Affine)
-add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(Test)
add_subdirectory(Tosa)
+++ /dev/null
-# Exclude tests from libMLIR.so
-add_mlir_library(MLIRShapeTestPasses
- TestShapeFunctions.cpp
-
- EXCLUDE_FROM_LIBMLIR
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRPass
- MLIRShape
- MLIRSupport
- )
+++ /dev/null
-//===- TestShapeFunctions.cpp - Passes to test shape function ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include <queue>
-
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/Interfaces/InferTypeOpInterface.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-
-namespace {
-/// This is a pass that reports shape functions associated with ops.
-struct ReportShapeFnPass
- : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
- void runOnOperation() override;
-};
-} // end anonymous namespace
-
-void ReportShapeFnPass::runOnOperation() {
- auto module = getOperation();
-
- // Lookup shape function library.
- shape::FunctionLibraryOp shapeFnLib = nullptr;
- for (auto lib : module.getOps<shape::FunctionLibraryOp>()) {
- if (shapeFnLib) {
- lib.emitError("duplicate shape library op")
- .attachNote(shapeFnLib.getLoc())
- << "previous mapping";
- return signalPassFailure();
- }
- shapeFnLib = lib;
- };
-
- // Report the shape function available to refine the op.
- auto shapeFnId = Identifier::get("shape.function", &getContext());
- auto remarkShapeFn = [&](Operation *op) {
- if (op->isKnownTerminator())
- return;
- if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
- op->emitRemark() << "implements InferType op interface";
- } else if (auto fn = shapeFnLib.getShapeFunction(op)) {
- op->emitRemark() << "associated shape function: " << fn.getName();
- } else if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
- auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
- op->emitRemark() << "associated shape function: " << fn.getName();
- } else {
- op->emitRemark() << "no associated way to refine shape";
- }
- };
-
- module.getBodyRegion().walk([&](FuncOp func) {
- // Skip ops in the shape function library.
- if (isa<shape::FunctionLibraryOp>(func.getParentOp()))
- return;
-
- func.walk([&](Operation *op) { remarkShapeFn(op); });
- });
-}
-
-namespace mlir {
-void registerShapeFunctionTestPasses() {
- PassRegistration<ReportShapeFnPass>(
- "test-shape-function-report",
- "Test pass to report associated shape functions");
-}
-} // namespace mlir
let results = (outs AnySignlessInteger:$result);
}
-def SameOperandsResultType : TEST_Op<
- "same_operand_result_type", [SameOperandsAndResultType]> {
- let arguments = (ins AnyTensor:$operand);
- let results = (outs AnyTensor:$result);
-}
-
//===----------------------------------------------------------------------===//
// Test Results
//===----------------------------------------------------------------------===//
if(MLIR_INCLUDE_TESTS)
set(test_libs
MLIRAffineTransformsTestPasses
- MLIRShapeTestPasses
MLIRSPIRVTestPasses
MLIRTestDialect
MLIRTestIR
void registerConvertToTargetEnvPass();
void registerPassManagerTestPass();
void registerPrintOpAvailabilityPass();
-void registerShapeFunctionTestPasses();
void registerSideEffectTestPasses();
void registerSliceAnalysisTestPass();
void registerSymbolTestPasses();
registerConvertToTargetEnvPass();
registerPassManagerTestPass();
registerPrintOpAvailabilityPass();
- registerShapeFunctionTestPasses();
registerSideEffectTestPasses();
registerSliceAnalysisTestPass();
registerSymbolTestPasses();