#define SHAPE_OPS
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
def Shape_FunctionLibraryOp : Shape_Op<"function_library",
[AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
- NoTerminator, SingleBlock]> {
+ NoTerminator, OpAsmOpInterface, SingleBlock]> {
let summary = "Represents shape functions and corresponding ops";
let description = [{
Represents a list of shape functions and the ops whose shape transfer
```mlir
shape.function_library {
- func.func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
- %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+ func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+ %0 = shape_of %arg : !shape.value_shape -> !shape.shape
return %0 : !shape.shape
}
} mapping {
let extraClassDeclaration = [{
/// Returns an associated shape function for an operation if defined.
- func::FuncOp getShapeFunction(Operation *op);
+ FuncOp getShapeFunction(Operation *op);
+
+ //===------------------------------------------------------------------===//
+ // OpAsmOpInterface
+ //===------------------------------------------------------------------===//
+
+ // This will filter the `shape.` prefix in front of operations inside the
+ // func body.
+ static StringRef getDefaultDialect() { return "shape";}
}];
let builders = [OpBuilder<(ins "StringRef":$name)>];
let hasCustomAssemblyFormat = 1;
}
+def Shape_FuncOp : Shape_Op<"func",
+ [AffineScope, AutomaticAllocationScope, CallableOpInterface,
+ FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface, Symbol]> {
+ let summary = "Shape function";
+ let description = [{
+ An operation with a name containing a single `SSACFG` region which
+ represents a shape transfer function or helper function for shape transfer
+ function.
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let regions = (region AnyRegion:$body);
+
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // CallableOpInterface
+ //===------------------------------------------------------------------===//
+
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+ /// Returns the results types that the callable region produces when
+ /// executed.
+ ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the argument types of this function.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of this function.
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // OpAsmOpInterface
+ //===------------------------------------------------------------------===//
+
+ // This will filter the `shape.` prefix in front of operations inside the
+ // func body.
+ static StringRef getDefaultDialect() { return "shape";}
+
+ //===------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ bool isDeclaration() { return isExternal(); }
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
+def Shape_ReturnOp : Shape_Op<"return",
+ [NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator]> {
+ let summary = "Shape function return operation";
+ let description = [{
+ The `shape.return` operation represents a return operation within a function.
+ The operation takes variable number of operands and produces no results.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+ // TODO: Tighten verification.
+}
+
#endif // SHAPE_OPS
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
-func::FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
+FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
auto attr = getMapping()
.get(op->getName().getIdentifier())
.dyn_cast_or_null<FlatSymbolRefAttr>();
if (!attr)
return nullptr;
- return lookupSymbol<func::FuncOp>(attr);
+ return lookupSymbol<FuncOp>(attr);
}
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
}
//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto buildFuncType =
+ [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+ function_interface_impl::VariadicFlag,
+ std::string &) { return builder.getFunctionType(argTypes, results); };
+
+ return function_interface_impl::parseFunctionOp(
+ parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//