```c++
// All result-types/operands/attributes have one aggregate parameter.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
ArrayRef<Type> resultTypes,
ValueRange operands,
ArrayRef<NamedAttribute> attributes);
// Each result-type/operand/attribute has a separate parameter. The parameters
// for attributes are of mlir::Attribute types.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
Type i32_result, Type f32_result, ...,
Value i32_operand, Value f32_operand, ...,
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
// for attributes are raw values unwrapped with mlir::Attribute instances.
// (Note that this builder will not always be generated. See the following
// explanation for more details.)
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
Type i32_result, Type f32_result, ...,
Value i32_operand, Value f32_operand, ...,
APInt i32_attr, StringRef f32_attr, ...);
// Each operand/attribute has a separate parameter but result type is aggregate.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
ArrayRef<Type> resultTypes,
Value i32_operand, Value f32_operand, ...,
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
// All operands/attributes have aggregate parameters.
// Generated if InferTypeOpInterface interface is specified.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
ValueRange operands,
ArrayRef<NamedAttribute> attributes);
* The op's traits (e.g., commutative) are modelled along with the op in the
registry.
* The op's operand/return type constraints are modelled along with the op in
- the registry (see [Shape inference](#shape-inference) discussion below),
+ the registry (see [Shape inference](ShapeInference.md) discussion below),
this allows (e.g.) optimized concise syntax in textual dumps.
* Behavior of the op is documented along with the op with a summary and a
description. The description is written in markdown and extracted for
Printing is effectively the inverse of the parsing function generated with the
mnemonic string serving as a template.
-### Shape inference
-
-Type constraints are along (at least) three axis: 1) elemental type, 2) rank
-(including static or dynamic), 3) dimensions. While some ops have no compile
-time fixed shape (e.g., output shape is dictated by data) we could still have
-some knowledge of constraints/bounds in the system for that op (e.g., the output
-of a `tf.where` is at most the size of the input data). And so there are
-additional valuable constraints that could be captured even without full
-knowledge.
-
-Initially the shape inference will be declaratively specified using:
-
-* Constraint on the operands of an operation directly. For example
- constraining the input type to be tensor/vector elements or that the
- elemental type be of a specific type (e.g., output of sign is of elemental
- type `i1`) or class (e.g., float like).
-* Constraints across operands and results of an operation. For example,
- enabling specifying equality constraints on type/constituents of a type
- (shape and elemental type) between operands and results (e.g., the output
- type of an add is the same as those of the input operands).
-
-In general there is an input/output transfer function which maps the inputs to
-the outputs (e.g., given input X and Y [or slices thereof] with these sizes, the
-output is Z [or this slice thereof]). Such a function could be used to determine
-the output type (shape) for given input type (shape).
-
-But shape functions are determined by attributes and could be arbitrarily
-complicated with a wide-range of specification possibilities. Equality
-relationships are common (e.g., the elemental type of the output matches the
-primitive type of the inputs, both inputs have exactly the same type [primitive
-type and shape]) and so these should be easy to specify. Algebraic relationships
-would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0
-is `[n+n, m]` matrix), while some ops only have defined shapes under certain
-cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if
-`b == c`). As ops are also verified, the shape inference need only specify rules
-for the allowed cases (e.g., shape inference for matmul can ignore the case
-where `b != c`), which would simplify type constraint specification.
-
-Instead of specifying an additional mechanism to specify a shape transfer
-function, the reference implementation of the operation will be used to derive
-the shape function. The reference implementation is general and can support the
-arbitrary computations needed to specify output shapes.
-
[TableGen]: https://llvm.org/docs/TableGen/index.html
[TableGenIntro]: https://llvm.org/docs/TableGen/LangIntro.html
[TableGenRef]: https://llvm.org/docs/TableGen/LangRef.html
--- /dev/null
+# Shape inference
+
+Shape inference as discussed here is considered a specific instance of type
+inference for [ShapedType][ShapedType]. Type constraints are along (at least)
+three axis: 1) elemental type, 2) rank (including static or dynamic), 3)
+dimensions. While some operations have no compile time fixed shape (e.g., output
+shape is dictated by data) we could still have some knowledge of
+constraints/bounds in the system for that operation (e.g., the output of a
+`tf.where` is at most the size of the input data). That is, there are additional
+valuable constraints that could be captured even without full knowledge of the
+shape.
+
+Type inference is currently modelled executionally for op creation using the
+[`InferTypeOpInterface`][InferTypeOpInterface], while
+`InferShapedTypeOpInterface` is used to implement the shape and element type
+inference. The return type can often be deduced from the deduced return shape
+and elemental type (queryable from `InferShapedTypeOpInterface`) and so type
+inference for tensor types can be implemented with `InferShapedTypeOpInterface`.
+
+## Shape functions
+
+The C++ interfaces are the base mechanism whereby shape inference is queried and
+executed, but not the intended way to specify shape constraints in general.
+
+Initially the shape inference will be declaratively specified using:
+
+* Constraints on the operands of an operation directly. For example
+ constraining the input type to be tensor/vector elements or that the
+ elemental type be of a specific type (e.g., output of computing the size
+ of a value is of elemental type `i1`) or class (e.g., float like).
+* Constraints across operands and results of an operation.
+
+ - For example, specifying equality constraints on type/constituents of a
+ type (shape and elemental type) between operands and results (e.g., the
+ output type of an add is the same as those of the input operands).
+
+NOTE: The C++ shape functions are an intermediate step until the shape dialect
+is more full-fledged, at which point the C++ functions should become the
+exceptional case.
+
+## Testing
+
+Shape inference is currently tested alongside type inference by
+`TestReturnTypeDriver` in the test dialect. The driver performs two checks:
+
+1. Verification that the return types specified matches the infered types. This
+ explicit check will be removed and made part of Op verificaton instead.
+2. Test the creation of Ops without specifying the return type explicitly in
+ function `testCreateFunctions` by creating new binary Ops (Op classes
+ specified in `TestReturnTypeDriver`) using 1) all operands to
+ `testCreateFunctions` as both operands, and 2) using combinations of input
+ operands of the function.
+
+## WIP/Future considerations
+
+Shape functions are determined by attributes and could be arbitrarily
+complicated with a wide-range of specification possibilities. Equality
+relationships are common (e.g., the elemental type of the output matches the
+primitive type of the inputs, both inputs have exactly the same type [primitive
+type and shape]) and so these should be easy to specify. Algebraic relationships
+would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0
+is `[n+n, m]` matrix), while some ops only have defined shapes under certain
+cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if `b
+== c`).
+
+Instead of specifying an additional mechanism to specify a shape transfer
+function, the reference implementation of the operation will be used to derive
+the shape function. The reference implementation is general and can support the
+arbitrary computations needed to specify output shapes.
+
+[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/Analysis/InferTypeOpInterface.td
+[ShapedType]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/IR/StandardTypes.h
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
+/// ShapedTypeComponents that represents the components of a ShapedType.
+/// The components consist of
+/// - A ranked or unranked shape with the dimension specification match those
+/// of ShapeType's getShape() (e.g., dynamic dimension represented using
+/// ShapedType::kDynamicSize)
+/// - A element type, may be unset (nullptr)
+/// - A attribute, may be unset (nullptr)
+/// Used by ShapedType type inferences.
+class ShapedTypeComponents {
+ /// Internal storage type for shape.
+ using ShapeStorageT = SmallVector<int64_t, 3>;
+
+public:
+ /// Default construction is an unranked shape.
+ ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){};
+
+ template <typename Arg, typename = typename std::enable_if_t<
+ std::is_constructible<ShapeStorageT, Arg>::value>>
+ ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
+ Attribute attr = nullptr)
+ : dims(std::forward<Arg>(arg)), ranked(true), elementType(elementType),
+ attr(attr) {}
+ ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
+ Attribute attr = nullptr)
+ : dims(vec.begin(), vec.end()), ranked(true), elementType(elementType),
+ attr(attr) {}
+
+ /// Return the dimensions of the shape.
+ /// Requires: shape is ranked.
+ ArrayRef<int64_t> getDims() const {
+ assert(ranked && "requires ranked shape");
+ return dims;
+ }
+
+ /// Return whether the shape has a rank.
+ bool hasRank() const { return ranked; };
+
+ /// Return the element type component.
+ Type getElementType() const { return elementType; };
+
+ /// Return the raw attribute component.
+ Attribute getAttribute() const { return attr; };
+
+private:
+ ShapeStorageT dims;
+ bool ranked;
+ Type elementType;
+ Attribute attr;
+};
+
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
+namespace detail {
+// Helper function to infer return tensor returns types given element and shape
+// inference function.
+//
+// TODO: Consider generating typedefs for trait member functions if this usage
+// becomes more common.
+LogicalResult inferReturnTensorTypes(
+ function_ref<LogicalResult(
+ MLIRContext *, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<ShapedTypeComponents> &retComponents)>
+ componentTypeFn,
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferedReturnTypes);
+} // namespace detail
+
namespace OpTrait {
+
+/// Tensor type inference trait that constructs a tensor from the infered
+/// shape and elemental types.
+/// Requires: Op implements functions of InferShapedTypeOpInterface.
template <typename ConcreteType>
-class TypeOpInterfaceDefault
- : public TraitBase<ConcreteType, TypeOpInterfaceDefault> {
+class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
public:
- /// Returns whether two arrays are equal as strongest check for compatibility
- /// by default.
- static bool isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
- return lhs == rhs;
- };
+ static LogicalResult
+ inferReturnTypes(MLIRContext *context, Optional<Location> location,
+ ValueRange operands, ArrayRef<NamedAttribute> attributes,
+ RegionRange regions,
+ SmallVectorImpl<Type> &inferedReturnTypes) {
+ return ::mlir::detail::inferReturnTensorTypes(
+ ConcreteType::inferReturnTypeComponents, context, location, operands,
+ attributes, regions, inferedReturnTypes);
+ }
};
-} // namespace OpTrait
+} // namespace OpTrait
} // namespace mlir
#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
// mismatch).
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
let description = [{
- Interface to access a registered method to infer the return types for an
- operation that could be used during op construction, verification or
- type inference.
+ Interface to infer the return types for an operation that could be used
+ during op construction, verification or type inference.
}];
let methods = [
}],
/*retTy=*/"LogicalResult",
/*methodName=*/"inferReturnTypes",
- /*args=*/(ins "Optional<Location>":$location,
+ /*args=*/(ins "MLIRContext*":$context,
+ "Optional<Location>":$location,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes,
"RegionRange":$regions,
];
}
+def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
+ let description = [{
+ Interface to infer the components of a ShapedType returned by an operation
+ that could be used during op construction, verification or shape inference.
+
+ The components consists of element type, shape and raw attribute.
+ }];
+
+ let methods = [
+ StaticInterfaceMethod<
+ /*desc=*/[{Infer the components of return type of shape containter.
+
+ The method takes an optional location which, if set, will be used to
+ report errors on. The operands and attributes correspond to those with
+ which an Operation would be created (e.g., as used in Operation::create)
+ and the regions of the op.
+
+ Unknown (e.g., unranked) shape and nullptrs for element type and attribute
+ may be returned by this function while returning success. E.g., partial
+ population of components is not error condition.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"inferReturnTypeComponents",
+ /*args=*/(ins "MLIRContext*":$context,
+ "Optional<Location>":$location,
+ "ValueRange":$operands,
+ "ArrayRef<NamedAttribute>":$attributes,
+ "RegionRange":$regions,
+ "SmallVectorImpl<ShapedTypeComponents>&":
+ $inferedReturnShapes)
+ >,
+ ];
+}
+
#endif // MLIR_INFERTYPEOPINTERFACE
// following signatures:
//
// ```c++
- // static void build(Builder *, OperationState &tblgen_state,
+ // static void build(Builder *, OperationState &odsState,
// Type <result0-name>, Type <result1-name>, ...,
// Value <arg0-name>, Value <arg1-name>, ...,
// Attribute <attr0-name>, Attribute <attr1-name>, ...);
// * where the attributes follow the same declaration order as in the op.
//
// ```c++
- // static void build(Builder *, OperationState &tblgen_state,
+ // static void build(Builder *, OperationState &odsState,
// ArrayRef<Type> resultTypes,
// ArrayRef<Value> operands,
// ArrayRef<NamedAttribute> attributes);
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/InferTypeOpInterface.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Types.h"
-#include "llvm/ADT/SmallVector.h"
+
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
namespace mlir {
#include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
} // namespace mlir
+
+LogicalResult mlir::detail::inferReturnTensorTypes(
+ function_ref<LogicalResult(
+ MLIRContext *, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<ShapedTypeComponents> &retComponents)>
+ componentTypeFn,
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferedReturnTypes) {
+ SmallVector<ShapedTypeComponents, 2> retComponents;
+ if (failed(componentTypeFn(context, location, operands, attributes, regions,
+ retComponents)))
+ return failure();
+ for (auto shapeAndType : retComponents) {
+ assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
+ if (shapeAndType.hasRank())
+ inferedReturnTypes.push_back(RankedTensorType::get(
+ shapeAndType.getDims(), shapeAndType.getElementType()));
+ else
+ inferedReturnTypes.push_back(
+ UnrankedTensorType::get(shapeAndType.getElementType()));
+ }
+ return success();
+}
}
LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
- llvm::Optional<Location> location, ValueRange operands,
+ MLIRContext *, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return success();
}
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<ShapedTypeComponents> &inferedComponents) {
+ // Create return type consisting of the first element of each shape of the
+ // input operands or unknown for unranked operand.
+ std::vector<int64_t> shape;
+ shape.reserve(operands.size());
+ for (auto operandType : operands.getTypes()) {
+ if (auto sval = operandType.dyn_cast<ShapedType>()) {
+ if (sval.hasRank())
+ shape.push_back(sval.getShape().front());
+ else
+ shape.push_back(ShapedType::kDynamicSize);
+ } else {
+ return emitOptionalError(location, "only shaped type operands allowed");
+ }
+ }
+ inferedComponents.reserve(1);
+ auto type = IntegerType::get(17, context);
+ inferedComponents.emplace_back(shape, type);
+ return success();
+}
+
// Static initialization for Test dialect registration.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
let results = (outs AnyTensor);
}
+def InferTensorType : NativeOpTrait<"InferTensorType">;
+def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
+ [
+ // Op implements infer type op interface.
+ InferTypeOpInterface,
+ // The op will have methods implementing the ShapedType type infer interface.
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
+ // The op produces tensors and will use the ShapedType type infer interface
+ // along with knowledge that it is producing Tensors to infer shape.
+ InferTensorType
+ ]> {
+ let arguments = (ins AnyTensor, AnyTensor);
+ let results = (outs AnyTensor);
+}
+
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
//===----------------------------------------------------------------------===//
namespace {
-struct ReturnTypeOpMatch : public RewritePattern {
- ReturnTypeOpMatch(MLIRContext *ctx)
- : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
- }
-
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const final {
- if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
- SmallVector<Value, 4> values(op->getOperands());
+// Generate ops for each instance where the type can be succesfully infered.
+template <typename OpTy>
+static void invokeCreateWithInferedReturnType(Operation *op) {
+ auto *context = op->getContext();
+ auto fop = op->getParentOfType<FuncOp>();
+ auto location = UnknownLoc::get(context);
+ OpBuilder b(op);
+ b.setInsertionPointAfter(op);
+
+ // Use permutations of 2 args as operands.
+ assert(fop.getNumArguments() >= 2);
+ for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
+ for (int j = 0; j < e; ++j) {
+ std::array<Value, 2> values = {fop.getArgument(i), fop.getArgument(j)};
SmallVector<Type, 2> inferedReturnTypes;
- if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values,
- op->getAttrs(), op->getRegions(),
- inferedReturnTypes)))
- return matchFailure();
- SmallVector<Type, 1> resultTypes(op->getResultTypes());
- if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
- return op->emitOpError(
- "inferred type incompatible with return type of operation"),
- matchFailure();
-
- // TODO(jpienaar): Split this out to make the test more focused.
- // Create new op with unknown location to verify building with
- // InferTypeOpInterface is triggered.
- auto fop = op->getParentOfType<FuncOp>();
- if (values[0] == fop.getArgument(0)) {
- // Use the 2nd function argument if the first function argument is used
- // when constructing the new op so that a new return type is inferred.
- values[0] = fop.getArgument(1);
- values[1] = fop.getArgument(1);
+ if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values,
+ op->getAttrs(), op->getRegions(),
+ inferedReturnTypes))) {
+ OperationState state(location, OpTy::getOperationName());
// TODO(jpienaar): Expand to regions.
- rewriter.create<OpWithInferTypeInterfaceOp>(
- UnknownLoc::get(op->getContext()), values, op->getAttrs());
+ OpTy::build(&b, state, values, op->getAttrs());
+ (void)b.createOperation(state);
}
}
- return matchFailure();
}
-};
+}
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns;
- populateWithGenerated(&getContext(), &patterns);
- patterns.insert<ReturnTypeOpMatch>(&getContext());
- applyPatternsGreedily(getFunction(), patterns);
+ if (getFunction().getName() == "testCreateFunctions") {
+ std::vector<Operation *> ops;
+ // Collect ops to avoid triggering on inserted ops.
+ for (auto &op : getFunction().getBody().front())
+ ops.push_back(&op);
+ // Generate test patterns for each, but skip terminator.
+ for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
+ // Test create method of each of the Op classes below. The resultant
+ // output would be in reverse order underneath `op` from which
+ // the attributes and regions are used.
+ invokeCreateWithInferedReturnType<OpWithInferTypeInterfaceOp>(op);
+ invokeCreateWithInferedReturnType<OpWithShapedTypeInferTypeInterfaceOp>(
+ op);
+ };
+ return;
+ }
+
+ // Verification check.
+ // TODO: Move to ops that implement type infer interface.
+ getFunction().walk([this](Operation *op) -> void {
+ auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
+ if (!retTypeFn)
+ return;
+ auto *context = &getContext();
+ SmallVector<Type, 2> inferedReturnTypes;
+ if (failed(retTypeFn.inferReturnTypes(
+ context, op->getLoc(), op->getOperands(), op->getAttrs(),
+ op->getRegions(), inferedReturnTypes)))
+ return;
+ SmallVector<Type, 1> resultTypes(op->getResultTypes());
+ if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
+ op->emitOpError(
+ "inferred type incompatible with return type of operation");
+ return;
+ }
+ });
}
};
} // end anonymous namespace
// ---
// DEF: void AOp::build(
-// DEF: tblgen_state.addAttribute("aAttr", aAttr);
-// DEF: tblgen_state.addAttribute("bAttr", bAttr);
+// DEF: odsState.addAttribute("aAttr", aAttr);
+// DEF: odsState.addAttribute("bAttr", bAttr);
// DEF: if (cAttr) {
-// DEF-NEXT: tblgen_state.addAttribute("cAttr", cAttr);
+// DEF-NEXT: odsState.addAttribute("cAttr", cAttr);
// DEF: void AOp::build(
// DEF: some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr
-// DEF: tblgen_state.addAttribute("aAttr", some-const-builder-call((*tblgen_builder), aAttr));
+// DEF: odsState.addAttribute("aAttr", some-const-builder-call((*odsBuilder), aAttr));
// DEF: void AOp::build(
// DEF: ArrayRef<NamedAttribute> attributes
-// DEF: tblgen_state.addAttributes(attributes);
+// DEF: odsState.addAttributes(attributes);
// Test verify method
// ---
// DEF-LABEL: MixOperandsAndAttrs definitions
// DEF-DAG: Value MixOperandsAndAttrs::operand()
// DEF-DAG: Value MixOperandsAndAttrs::otherArg()
-// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg)
+// DEF-DAG: void MixOperandsAndAttrs::build(Builder *odsBuilder, OperationState &odsState, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg)
// DEF-DAG: APFloat MixOperandsAndAttrs::attr()
// DEF-DAG: APFloat MixOperandsAndAttrs::otherAttr()
// DEF: bool UnitAttrOp::attr() {
// DEF: return {{.*}} != nullptr
-// DEF: build(Builder *tblgen_builder, OperationState &tblgen_state, /*optional*/UnitAttr attr)
+// DEF: build(Builder *odsBuilder, OperationState &odsState, /*optional*/UnitAttr attr)
// CHECK: FloatAttr attr2Attr()
// CHECK: Optional< APFloat > attr2();
// CHECK: static void build(Value val);
-// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
-// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
-// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
+// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
+// CHECK: static void build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
// CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result);
// CHECK: void print(OpAsmPrinter &p);
// CHECK: LogicalResult verify();
// CHECK: void OpA::build
// CHECK: Value input
-// CHECK: tblgen_state.addOperands(input);
+// CHECK: odsState.addOperands(input);
// CHECK: void OpA::build
// CHECK: ValueRange operands
// CHECK: assert(operands.size() == 1u && "mismatched number of parameters");
-// CHECK: tblgen_state.addOperands(operands);
+// CHECK: odsState.addOperands(operands);
def OpB : NS_Op<"one_variadic_operand_op", []> {
let arguments = (ins Variadic<I32>:$input);
// CHECK-LABEL: OpB::build
// CHECK: ValueRange input
// CHECK-NOT: assert
-// CHECK: tblgen_state.addOperands(input);
+// CHECK: odsState.addOperands(input);
def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
// CHECK-NEXT: return *getODSOperands(1).begin();
// CHECK-LABEL: OpD::build
-// CHECK-NEXT: tblgen_state.addOperands(input1);
-// CHECK-NEXT: tblgen_state.addOperands(input2);
-// CHECK-NEXT: tblgen_state.addOperands(input3);
+// CHECK-NEXT: odsState.addOperands(input1);
+// CHECK-NEXT: odsState.addOperands(input2);
+// CHECK-NEXT: odsState.addOperands(input3);
// CHECK-LABEL: void OpA::build
// CHECK: ArrayRef<Type> resultTypes, ValueRange operands
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
-// CHECK-NEXT: tblgen_state.addTypes(resultTypes);
+// CHECK-NEXT: odsState.addTypes(resultTypes);
def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
let arguments = (ins I32:$x);
}
// CHECK-LABEL: OpB definitions
-// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, Value x)
-// CHECK: tblgen_state.addTypes(y);
-// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Value x)
-// CHECK: tblgen_state.addTypes({x.getType()});
+// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Type y, Value x)
+// CHECK: odsState.addTypes(y);
+// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Value x)
+// CHECK: odsState.addTypes({x.getType()});
def OpC : NS_Op<"three_normal_result_op", []> {
let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
}
// CHECK-LABEL: OpC definitions
-// CHECK: void OpC::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, Type resultType1, Type z)
-// CHECK-NEXT: tblgen_state.addTypes(x)
-// CHECK-NEXT: tblgen_state.addTypes(resultType1)
-// CHECK-NEXT: tblgen_state.addTypes(z)
+// CHECK: void OpC::build(Builder *odsBuilder, OperationState &odsState, Type x, Type resultType1, Type z)
+// CHECK-NEXT: odsState.addTypes(x)
+// CHECK-NEXT: odsState.addTypes(resultType1)
+// CHECK-NEXT: odsState.addTypes(z)
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
}
// CHECK-LABEL: OpD definitions
-// CHECK: void OpD::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef<NamedAttribute> attributes)
-// CHECK: tblgen_state.addTypes({attr.second.cast<TypeAttr>().getValue()});
+// CHECK: void OpD::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: odsState.addTypes({attr.second.cast<TypeAttr>().getValue()});
def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, F32Attr:$attr);
}
// CHECK-LABEL: OpE definitions
-// CHECK: void OpE::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef<NamedAttribute> attributes)
-// CHECK: tblgen_state.addTypes({attr.second.getType()});
+// CHECK: void OpE::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: odsState.addTypes({attr.second.getType()});
def OpF : NS_Op<"one_variadic_result_op", []> {
let results = (outs Variadic<I32>:$x);
// CHECK-LABEL: void OpF::build
// CHECK-SAME: ArrayRef<Type> x
// CHECK-NOT: assert
-// CHECK: tblgen_state.addTypes(x);
+// CHECK: odsState.addTypes(x);
def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
// CHECK-LABEL: OpG definitions
-// CHECK: void OpG::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, ArrayRef<Type> y)
-// CHECK-NEXT: tblgen_state.addTypes(x);
-// CHECK-NEXT: tblgen_state.addTypes(y);
+// CHECK: void OpG::build(Builder *odsBuilder, OperationState &odsState, Type x, ArrayRef<Type> y)
+// CHECK-NEXT: odsState.addTypes(x);
+// CHECK-NEXT: odsState.addTypes(y);
// CHECK: void OpG::build
// CHECK: ArrayRef<Type> resultTypes
// CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types");
-// CHECK-NEXT: tblgen_state.addTypes(resultTypes);
+// CHECK-NEXT: odsState.addTypes(resultTypes);
def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
// CHECK-NEXT: return *getODSResults(1).begin();
// CHECK-LABEL: OpI::build
-// CHECK-NEXT: tblgen_state.addTypes(output1);
-// CHECK-NEXT: tblgen_state.addTypes(output2);
-// CHECK-NEXT: tblgen_state.addTypes(output3);
+// CHECK-NEXT: odsState.addTypes(output1);
+// CHECK-NEXT: odsState.addTypes(output2);
+// CHECK-NEXT: odsState.addTypes(output3);
// Test that if the only operand is variadic, we access the first value in the
// pack to set result type
let results = (outs AnyTensor:$result);
}
-// CHECK-LABEL: OpK::build(Builder *tblgen_builder, OperationState &tblgen_state, ValueRange input)
-// CHECK: tblgen_state.addTypes({input.front().getType()});
+// CHECK-LABEL: OpK::build(Builder *odsBuilder, OperationState &odsState, ValueRange input)
+// CHECK: odsState.addTypes({input.front().getType()});
// RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure
-// CHECK-LABEL: testReturnTypeOpInterface
-func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
- %good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
- // CHECK: test.op_with_infer_type_if
- // CHECK-SAME: tensor<20xi32>
- // CHECK: test.op_with_infer_type_if
- // CHECK-SAME: tensor<10xf32>
+// CHECK-LABEL: testCreateFunctions
+// This function tests invoking the create method with different inference
+// methods. The attributes of the ops inside are used to test creation.
+func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
+// CHECK: "test.no_attributes"
+ %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17>
+// CHECK: "test.op_with_infer_type_if"
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+// CHECK: "test.op_with_infer_type_if"
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi32>
return
}
//===----------------------------------------------------------------------===//
static const char *const tblgenNamePrefix = "tblgen_";
-static const char *const generatedArgName = "tblgen_arg";
-static const char *const builderOpState = "tblgen_state";
+static const char *const generatedArgName = "odsArg";
+static const char *const builderOpState = "odsState";
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// TODO(jpienaar): Expand to handle regions.
body << formatv(R"(
SmallVector<Type, 2> inferedReturnTypes;
- if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands,
- {1}.attributes, /*regions=*/{{}, inferedReturnTypes)))
+ if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
+ {1}.location, {1}.operands, {1}.attributes,
+ /*regions=*/{{}, inferedReturnTypes)))
{1}.addTypes(inferedReturnTypes);
else
llvm::report_fatal_error("Failed to infer result type(s).");)",
void OpEmitter::genInferedTypeCollectiveParamBuilder() {
// TODO(jpienaar): Expand to support regions.
const char *params =
- "Builder *builder, OperationState &{0}, "
+ "Builder *odsBuilder, OperationState &{0}, "
"ValueRange operands, ArrayRef<NamedAttribute> attributes";
auto &m =
opClass.newMethod("void", "build", formatv(params, builderOpState).str(),
auto &body = m.body();
body << formatv(R"(
SmallVector<Type, 2> inferedReturnTypes;
- if (succeeded({0}::inferReturnTypes({1}.location, operands, attributes,
+ if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
+ {1}.location, operands, attributes,
/*regions=*/{{}, inferedReturnTypes)))
- build(builder, tblgen_state, inferedReturnTypes, operands, attributes);
+ build(odsBuilder, odsState, inferedReturnTypes, operands, attributes);
else
llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
- paramList = "Builder *tblgen_builder, OperationState &";
+ paramList = "Builder *odsBuilder, OperationState &";
paramList.append(builderOpState);
switch (typeParamKind) {
// If this is a raw value, then we need to wrap it in an Attribute
// instance.
FmtContext fctx;
- fctx.withBuilder("(*tblgen_builder)");
+ fctx.withBuilder("(*odsBuilder)");
std::string value =
tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name);
body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,