From e6f2f17f05a1248b069ba830c4afffd61ee2f297 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 11 Sep 2020 06:19:07 -0400 Subject: [PATCH] [mlir][Linalg] Refactor StructuredOpInterface - NFC This revision refactors and cleans up a bunch of things to simplify StructuredOpInterface before work can proceed on Linalg on tensors: - break out pieces of the StructuredOps trait that are part of the StructuredOpInterface, - drop referenceIterators and referenceIndexingMaps that end up being more confusing than useful, - drop NamedStructuredOpTrait --- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 61 ++- .../Linalg/IR/LinalgStructuredOpsInterface.td | 500 +++++++++++++++++---- mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h | 316 +------------ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 25 +- mlir/test/Dialect/Linalg/invalid.mlir | 19 +- .../mlir-linalg-ods-gen/test-linalg-ods-gen.tc | 21 +- .../mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp | 43 +- 7 files changed, 489 insertions(+), 496 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e003fd1..ac6e931 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -130,21 +130,22 @@ def CopyOp : LinalgStructured_Op<"copy", [ let extraClassDeclaration = libraryCallName # [{ // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. - llvm::Optional> referenceIterators() { - unsigned nPar = input().getType().cast().getRank(); - return SmallVector(nPar, getParallelIteratorTypeName()); + ArrayAttr iterator_types() { + unsigned nPar = getInputShapedType(0).getRank(); + return Builder(getContext()).getStrArrayAttr( + SmallVector(nPar, getParallelIteratorTypeName())); } // I(input_perm(ivs)) -> O(output_perm(ivs)) - llvm::Optional> referenceIndexingMaps() { + ArrayAttr indexing_maps() { MLIRContext *context = getContext(); auto maybeInputMap = inputPermutation(); auto maybeOutputMap = outputPermutation(); unsigned inputRank = getInputShapedType(0).getRank(); unsigned outputRank = getOutputShapedType(0).getRank(); - return SmallVector{ + return Builder(getContext()).getAffineMapArrayAttr({ extractOrIdentityMap(maybeInputMap, inputRank, context), - extractOrIdentityMap(maybeOutputMap, outputRank, context)}; + extractOrIdentityMap(maybeOutputMap, outputRank, context)}); } Value getSource() { return input();} @@ -163,16 +164,17 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { let extraClassDeclaration = libraryCallName # [{ // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. - llvm::Optional> referenceIterators() { - unsigned nPar = output().getType().cast().getRank(); - return SmallVector(nPar, getParallelIteratorTypeName()); + ArrayAttr iterator_types() { + unsigned nPar = getOutputShapedType(0).getRank(); + return Builder(getContext()).getStrArrayAttr( + SmallVector(nPar, getParallelIteratorTypeName())); } - llvm::Optional> referenceIndexingMaps() { + ArrayAttr indexing_maps() { MLIRContext *context = getContext(); // filling_value -> O(ivs) - return SmallVector{ - extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}; + return Builder(getContext()).getAffineMapArrayAttr({ + extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } }]; @@ -295,7 +297,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { getNumOutputFeatureDimensions(); } - llvm::Optional> referenceIterators() { + ArrayAttr iterator_types() { // Outer parallel loops are always the number of output dimensions; i.e. // [b, xs, q] in the TF notation above. unsigned nPar = getOutputShapedType(0).getRank(); @@ -310,7 +312,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { iters.reserve(nPar + nRed + nWin); iters.append(nRed, getReductionIteratorTypeName()); iters.append(nWin, getWindowIteratorTypeName()); - return iters; + return Builder(getContext()).getStrArrayAttr(iters); } // F(z0, ..., zN-1, q, k) * @@ -318,7 +320,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { // -> O(b, x0, ..., xN-1, k) // for N equal to `nWindow`. If there is no padding attribute, it will be // ignored. - llvm::Optional> referenceIndexingMaps() { + ArrayAttr indexing_maps() { MLIRContext *context = getContext(); auto nWin = getNumWindowLoops(); assert(nWin > 0 && "expected at least one window dimension"); @@ -343,7 +345,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { auto zs = makeAffineDimExprs(nWin, idx, context); // Construct the weighedSum expression. auto ws = weightedPoolingInputIndex(*this, xs, zs); - return SmallVector{ + return Builder(getContext()).getAffineMapArrayAttr({ // filter[z[0], ..., z[N-1], q, k] AffineMap::get(idx, 0, concat(concat(zs, qs), ks), context), // input[b, @@ -353,7 +355,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { // q] AffineMap::get(idx, 0, concat(concat(bs, ws), qs), context), // output[b, x[0], ..., x[N-1], k] - AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)}; + AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)}); } }]; @@ -384,7 +386,7 @@ class SingleInputPoolingBase_Op OptionalAttr:$padding); let extraClassDeclaration = commonUtils# [{ - llvm::Optional> referenceIterators() { + ArrayAttr iterator_types() { // Outer parallel loops are always the number of output dimensions. unsigned nPar = getOutputShapedType(0).getRank(); // The window loops has the same number loops with output dimensions. @@ -392,10 +394,10 @@ class SingleInputPoolingBase_Op SmallVector iters(nPar, getParallelIteratorTypeName()); iters.reserve(nPar + nWin); iters.append(nWin, getWindowIteratorTypeName()); - return iters; + return Builder(getContext()).getStrArrayAttr(iters); } - llvm::Optional> referenceIndexingMaps() { + ArrayAttr indexing_maps() { MLIRContext *context = getContext(); auto nPar = getNumParallelLoops(); auto nWin = getNumWindowLoops(); @@ -406,14 +408,13 @@ class SingleInputPoolingBase_Op // Construct the weighedSum expression. auto inputDims = weightedPoolingInputIndex(*this, outputDims, windowDims); - return SmallVector{ + return Builder(getContext()).getAffineMapArrayAttr({ // input AffineMap::get(idx, 0, inputDims, context), // windowDims AffineMap::get(idx, 0, windowDims, context), // output - AffineMap::get(idx, 0, outputDims, context) - }; + AffineMap::get(idx, 0, outputDims, context)}); } }]; @@ -466,7 +467,7 @@ class GenericOpBase : LinalgStructuredBase_Op:$library_call, Confined, [IntMinValue<0>]>:$symbol_source); - let results = (outs Variadic:$output_tensors); + let results = (outs Variadic:$output_lis); let regions = (region AnyRegion:$region); let extraClassDeclaration = [{ SmallVector linalgTraitAttrNames() { @@ -485,16 +486,6 @@ class GenericOpBase : LinalgStructuredBase_Op> referenceIterators() { - llvm_unreachable( - "No such thing as reference iterator types for a generic op."); - } - - llvm::Optional> referenceIndexingMaps() { - llvm_unreachable( - "No such thing as reference indexing maps for a generic op."); - } - llvm::Optional getSymbolSource() { auto ss = symbol_source(); return ss.hasValue() ? @@ -807,8 +798,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { // Named Linalg ops, implemented as a declarative configurations of generic ops. //===----------------------------------------------------------------------===// -def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">; - class LinalgNamedStructured_Op props> : LinalgStructuredBase_Op { string spec = ?; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td index 82882b0..f32b70e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -23,168 +23,486 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // Loop types handling. //===------------------------------------------------------------------===// InterfaceMethod< - "Return the number of parallel loops within the current operation.", - "unsigned", "getNumParallelLoops" + /*desc=*/[{ + Return the number of parallel loops within the current operation. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumParallelLoops", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getNumIterators(getParallelIteratorTypeName(), + $_op.iterator_types()); + }] >, InterfaceMethod< - "Return the number of reduction loops within the current operation.", - "unsigned", "getNumReductionLoops" + /*desc=*/[{ + Return the number of reduction loops within the current operation. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumReductionLoops", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getNumIterators(getReductionIteratorTypeName(), + $_op.iterator_types()); + }] >, InterfaceMethod< - "Return the number of window loops within the current operation.", - "unsigned", "getNumWindowLoops" + /*desc=*/[{ + Return the number of window loops within the current operation. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumWindowLoops", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getNumIterators(getWindowIteratorTypeName(), + $_op.iterator_types()); + }] >, InterfaceMethod< - "Return the number of loops within the current operation.", - "unsigned", "getNumLoops">, - + /*desc=*/[{ + Return the total number of loops within the current operation. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumLoops", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getNumIterators($_op.iterator_types()); + }] + >, InterfaceMethod< - [{Returns true if the current operation has only one loop and it's a - reduction loop}], - "bool", "hasSingleReductionLoop">, - + /*desc=*/[{ + Returns true if the current operation has only one loop and it's a + reduction loop. + }], + /*retTy=*/"bool", + /*methodName=*/"hasSingleReductionLoop", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto iters = $_op.iterator_types(); + return iters.size() == 1 && + getNumIterators(getReductionIteratorTypeName(), iters) == 1; + }]>, //===------------------------------------------------------------------===// - // Input arguments handling. + // Num input/output arguments handling. //===------------------------------------------------------------------===// + // These special methods must be defined by each op that wants to implement + // the LinalgStructuredInterface. For now, this is either: + // - inherited statically by using the NInputs or + // NOutputs traits. + // - derived from args_in/args_out attributes (for linalg.generic and + // linalg.indexed_generic ops). + InterfaceMethod< + /*desc=*/[{ + Return the number of inputs from the current operation. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumInputs" + >, InterfaceMethod< - "Return the number of inputs from the current operation.", - "unsigned", "getNumInputs" + /*desc=*/[{ + Return the number of outputs from the current operation. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumOutputs" >, - InterfaceMethod<"Return the input view at the given index.", - "Value", "getInput", (ins "unsigned":$i) + //===------------------------------------------------------------------===// + // Input arguments handling. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Return the `i`-th input value. + The `i^th` input argument is always the `i^th` operand regardless of + whether we have tensors or buffers. + }], + /*retTy=*/"Value", + /*methodName=*/"getInput", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < $_op.getNumInputs()); + return this->getOperation()->getOperand(i); + }] >, - InterfaceMethod<[{ + InterfaceMethod< + /*desc=*/[{ Return the index of the given input value `v`, or `None` if the value is not an input. }], - "llvm::Optional", "getIndexOfInput", (ins "Value":$v) + /*retTy=*/"llvm::Optional", + /*methodName=*/"getIndexOfInput", + /*args=*/(ins "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto it = llvm::find(getInputs(), value); + if (it != getInputs().end()) + return it - getInputs().begin(); + return llvm::None; + }] >, InterfaceMethod< - "Return the input operands from the current operation.", - "Operation::operand_range", "getInputs" - >, - InterfaceMethod<[{ + /*desc=*/[{ Return the `i`-th input shaped type, irrespective of buffer or tensor type. - }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, - InterfaceMethod<[{ + }], + /*retTy=*/"ShapedType", + /*methodName=*/"getInputShapedType", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getInput(i).getType().template cast(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the input operands from the current operation. + }], + /*retTy=*/"Operation::operand_range", + /*methodName=*/"getInputs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = this->getOperation()->getOperands(); + return {range.begin(), range.begin() + $_op.getNumInputs()}; + }] + >, + InterfaceMethod< + /*desc=*/[{ Return the subset of input operands that are of ranked tensor type. - }], "SmallVector", "getInputTensorTypes">, + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getInputTensorTypes" , + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector res; + for (Type type : getInputs().getTypes()) + if (auto t = type.template dyn_cast()) + res.push_back(t); + return res; + }] + >, //===------------------------------------------------------------------===// // Output arguments handling. //===------------------------------------------------------------------===// InterfaceMethod< - "Return the number of outputs from the current operation.", - "unsigned", "getNumOutputs" - >, - InterfaceMethod<"Return the output buffer at the given index.", - "Value", "getOutputBuffer", (ins "unsigned":$i) + /*desc=*/[{ + Return the output buffer at the given index, asserts that this is a + buffer operand and not a tensor result. + The `i^th` output argument is an operand (resp. a return value) iff it + is a value of buffer type (resp. a return value of tensor type). + }], + /*retTy=*/"Value", + /*methodName=*/"getOutputBuffer", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Output buffers are passed as output buffer operands (side-effecting). + // Output tensors are results. + // The union of the 2 are all the outputs and we want to ensure i does + // not overflow the buffer operands. + assert(i + this->getOperation()->getNumResults() < $_op.getNumOutputs() + && "overflowing output buffer index"); + return this->getOperation()->getOperand($_op.getNumInputs() + i); + }] >, - InterfaceMethod<[{ + InterfaceMethod< + /*desc=*/[{ Return the index of the given buffer value, or `None` if the value is not part of the output buffers. }], - "llvm::Optional", "getIndexOfOutputBuffer", (ins "Value":$view) + /*retTy=*/"llvm::Optional", + /*methodName=*/"getIndexOfOutputBuffer", + /*args=*/(ins "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto it = llvm::find(getOutputBuffers(), value); + if (it != getOutputBuffers().end()) + return it - getOutputBuffers().begin(); + return llvm::None; + }] >, - InterfaceMethod<[{ + InterfaceMethod< + /*desc=*/[{ Return the type of the output buffer at the given index. - }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, - InterfaceMethod<[{ + }], + /*retTy=*/"MemRefType", + /*methodName=*/"getOutputBufferType", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getOutputBuffer(i).getType().template cast(); + }]>, + InterfaceMethod< + /*desc=*/[{ Return the `i`-th output shaped type, irrespective of buffer or tensor type. - }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, - InterfaceMethod<[{ + }], + /*retTy=*/"ShapedType", + /*methodName=*/"getOutputShapedType", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getShapedType(i + $_op.getNumInputs()); + }]>, + InterfaceMethod< + /*desc=*/[{ Return the results that are of ranked tensor type. - }], "SmallVector", "getOutputTensorTypes">, + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputTensorTypes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector res; + for (Type type : this->getOperation()->getResults().getTypes()) + res.push_back(type.template cast()); + return res; + }]>, InterfaceMethod< - "Return the output buffers (operands) from the current operation.", - "Operation::operand_range", "getOutputBuffers" + /*desc=*/[{ + Return the output buffers (operands) from the current operation. + }], + /*retTy=*/"Operation::operand_range", + /*methodName=*/"getOutputBuffers", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = this->getOperation()->getOperands(); + return {range.begin() + $_op.getNumInputs(), + range.begin() + getNumInputsAndOutputBuffers()}; + }] >, //===------------------------------------------------------------------===// // Input and Output arguments handling. //===------------------------------------------------------------------===// InterfaceMethod< - "Return one single buffer at position `$i`.", - "Value", "getBuffer", (ins "unsigned":$i) + /*desc=*/[{ + Return one single buffer at position `$i`. + }], + /*retTy=*/"Value", + /*methodName=*/"getBuffer", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index"); + return this->getOperation()->getOperand(i); + }] >, InterfaceMethod< - "Return the number of inputs and outputs, irrespective of their buffer " - "or tensor type.", - "unsigned", "getNumInputsAndOutputs" + /*desc=*/[{ + Return the number of inputs and outputs, irrespective of their buffer or + tensor type. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumInputsAndOutputs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getNumInputs() + $_op.getNumOutputs(); + }] >, InterfaceMethod< - "Return the number of inputs, irrespective of their buffer or tensor " - "type, and output buffers", - "unsigned", "getNumInputsAndOutputBuffers" + /*desc=*/[{ + Return the number of inputs, irrespective of their buffer or tensor type + and output buffers + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumInputsAndOutputBuffers", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getNumInputs() + $_op.getNumOutputs() - + this->getOperation()->getNumResults(); + }] >, InterfaceMethod< - "Return the range over inputs (irrespective of type) and output buffers.", - "Operation::operand_range", "getInputsAndOutputBuffers" + /*desc=*/[{ + Return the range over inputs (irrespective of type) and output buffers. + }], + /*retTy=*/"Operation::operand_range", + /*methodName=*/"getInputsAndOutputBuffers", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = this->getOperation()->getOperands(); + return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; + }] >, InterfaceMethod< - "Return the shaped types for all the inputs and outputs", - "SmallVector", "getInputOutputShapedTypes" + /*desc=*/[{ + Return the `i`-th shaped type, there are 3 cases: + 1. if `i < $_op.getNumInputs()` then return `getInputShapedType(i)`; + otherwise + 2. if `i < getNumInputsAndOutputBuffers()` then return the + `getOutputBufferType(i - $_op.getNumInputs())`; otherwise + 3. return the `i - getNumInputsAndOutputBuffers()` result type. + }], + /*retTy=*/"ShapedType", + /*methodName=*/"getShapedType", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (i < $_op.getNumInputs()) + return getInputShapedType(i); + if (i < getNumInputsAndOutputBuffers()) + return getOutputBufferType(i - $_op.getNumInputs()); + return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]; + }]>, + InterfaceMethod< + /*desc=*/[{ + Return the shaped types for all the inputs and outputs + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getInputOutputShapedTypes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector inputOutputTypes( + this->getOperation()->operand_type_begin(), + this->getOperation()->operand_type_end()); + inputOutputTypes.append(this->getOperation()->result_type_begin(), + this->getOperation()->result_type_end()); + return llvm::to_vector<4>( + llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType { + return type.cast(); + })); + }] >, //===------------------------------------------------------------------===// // Other interface methods. //===------------------------------------------------------------------===// InterfaceMethod< - "Return the reference iterators for this named op (if any are " - "specified). These reference iterators are used to specify the default " - "behavior of the op. Typically this would be a static method but in " - "order to allow rank-polymorphic ops, this needs to be per object " - "instance. Named ops must define referenceIterators, even if empty for " - "the 0-D case. Generic ops on the other hand have a None " - "`referenceIterators`", - "llvm::Optional>", "referenceIterators" + /*desc=*/[{ + Return the iterator types attribute within the current operation. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"iterator_types", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.iterator_types(); + }] >, InterfaceMethod< - "Return the reference indexing maps for this named op (if any are " - "specified). Typically this would be a static method but in order to " - "allow rank-polymorphic ops, this needs to be per object instance. Named " - "ops must define referenceIterators, even if empty for the 0-D case. " - "Generic ops on the other hand have a None `referenceIndexingMaps`", - "llvm::Optional>", "referenceIndexingMaps" + /*desc=*/[{ + Return the indexing maps attribute within the current operation. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"indexing_maps" >, InterfaceMethod< - "Return the iterator types attribute within the current operation.", - "ArrayAttr", "iterator_types" + /*desc=*/[{ + Return the indexing maps within the current operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getIndexingMaps", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::to_vector<4>( + llvm::map_range($_op.indexing_maps(), + [](Attribute attr) -> AffineMap { + return attr.cast().getValue(); + })); + }] >, InterfaceMethod< - "Return the indexing maps attribute within the current operation.", - "ArrayAttr", "indexing_maps" + /*desc=*/[{ + Return the input or output indexing map at index `i`. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getIndexingMap", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < getNumInputsAndOutputs()); + return $_op.indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + }] >, InterfaceMethod< - "Return the indexing maps within the current operation.", - "SmallVector", "getIndexingMaps" - >, - InterfaceMethod<"Return the input or output indexing map at index `i`.", - "AffineMap", "getIndexingMap", (ins "unsigned":$i) - >, - InterfaceMethod<"Return the input indexing map at index `i`.", - "AffineMap", "getInputIndexingMap", (ins "unsigned":$i) + /*desc=*/[{ + Return the input indexing map at index `i`. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getInputIndexingMap", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < $_op.getNumInputs()); + return $_op.indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + }] >, - InterfaceMethod<"Return the output indexing map at index `i`.", - "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i) + InterfaceMethod< + /*desc=*/[{ + Return the output indexing map at index `i`. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getOutputIndexingMap", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < $_op.getNumOutputs()); + return $_op.indexing_maps() + .getValue()[i + $_op.getNumInputs()] + .template cast() + .getValue(); + }] >, - InterfaceMethod<[{ + InterfaceMethod< + /*desc=*/[{ Return whether the op has only MemRef input and outputs. - }], "bool", "hasBufferSemantics">, - InterfaceMethod<[{ + }], + /*retTy=*/"bool", + /*methodName=*/"hasBufferSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(getInputs(), + [](Value v) { return v.getType().isa(); }); + }] + >, + InterfaceMethod< + /*desc=*/[{ Return whether the op has only RankedTensor input and outputs. - }], "bool", "hasTensorSemantics">, + }], + /*retTy=*/"bool", + /*methodName=*/"hasTensorSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto isTensorType = [](Value v) { + return v.getType().isa(); + }; + return llvm::all_of(getInputs(), isTensorType) && + llvm::all_of(this->getOperation()->getResults(), isTensorType); + }] + >, //===------------------------------------------------------------------===// // Other static interface methods. //===------------------------------------------------------------------===// - StaticInterfaceMethod<[{ + StaticInterfaceMethod< + /*desc=*/[{ Create an operation of the current type with the given location, operands, and attributes. }], - "Operation *", "create", + /*retTy=*/"Operation *", + /*methodName=*/"create", (ins "OpBuilder &":$builder, "Location":$loc, "ValueRange":$operands, "ArrayRef":$attributes), [{ @@ -192,11 +510,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { attributes); }] >, - InterfaceMethod<[{ + InterfaceMethod< + /*desc=*/[{ Clone the current operation with the given location and operands. This is used to abstract away the optional underlying region creation. }], - "Operation *", "clone", + /*retTy=*/"Operation *", + /*methodName=*/"clone", (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{ BlockAndValueMapping map; unsigned numRegions = $_op.getOperation()->getNumRegions(); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index 8dda7d0..c4790ca 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -49,8 +49,8 @@ public: }; }; -/// This class provides the API for structured ops that are known to operate on -/// buffers or tensors. This trait must be used in conjunction with an op +/// This class provides a verifier for structured ops that are known to operate +/// on buffers or tensors. This trait must be used in conjunction with an op /// definition or a trait that provides the methods `getNumInputs` and /// `getNumOutputs`. Use as a trait as follows: /// @@ -59,324 +59,18 @@ public: template class StructuredOpTraits : public OpTrait::TraitBase { -private: - /// Return the number of inputs, irrespective of their buffer or tensor type. - /// For internal use only. - unsigned nInputs() { - return cast(this->getOperation()).getNumInputs(); - } - /// Return the number of outputs, irrespective of their buffer or tensor type. - /// For internal use only. - unsigned nOutputs() { - return cast(this->getOperation()).getNumOutputs(); - } - public: - //==========================================================================// - // Loop types handling. - //==========================================================================// - unsigned getNumParallelLoops() { - return getNumIterators( - getParallelIteratorTypeName(), - cast(this->getOperation()).iterator_types()); - } - unsigned getNumReductionLoops() { - return getNumIterators( - getReductionIteratorTypeName(), - cast(this->getOperation()).iterator_types()); - } - unsigned getNumWindowLoops() { - return getNumIterators( - getWindowIteratorTypeName(), - cast(this->getOperation()).iterator_types()); - } - unsigned getNumLoops() { - return getNumIterators( - cast(this->getOperation()).iterator_types()); - } - - bool hasSingleReductionLoop() { - auto iterators = cast(this->getOperation()).iterator_types(); - return iterators.size() == 1 && - getNumIterators(getReductionIteratorTypeName(), iterators); - } - - //==========================================================================// - // Input arguments handling. - //==========================================================================// - // The `i^th` input argument is always the `i^th` operand regardless of - // whether we have tensors or buffers. - // - /// Return the `i`-th input value. - Value getInput(unsigned i) { - assert(i < nInputs()); - return this->getOperation()->getOperand(i); - } - /// Return the index of `value` in the list of inputs if found, llvm::None - /// otherwise. - Optional getIndexOfInput(Value value) { - auto it = llvm::find(getInputs(), value); - if (it != getInputs().end()) - return it - getInputs().begin(); - return llvm::None; - } - /// Return the `i`-th input shaped type, irrespective of buffer or tensor - /// type. - ShapedType getInputShapedType(unsigned i) { - return getInput(i).getType().template cast(); - } - /// Return the range over inputs. - Operation::operand_range getInputs() { - auto range = this->getOperation()->getOperands(); - return {range.begin(), range.begin() + nInputs()}; - } - /// Query the subset of input operands that are of ranked tensor type. - SmallVector getInputTensorTypes() { - SmallVector res; - for (Type type : getInputs().getTypes()) - if (auto t = type.template dyn_cast()) - res.push_back(t); - return res; - } - - //==========================================================================// - // Output arguments handling. - //==========================================================================// - // The `i^th` output argument is an operand (resp. a return value) iff it is - // a value of buffer type (resp. a return value of tensor type). - - /// Return the `i`-th output, asserts that this is a buffer operand and not - /// a tensor result. - Value getOutputBuffer(unsigned i) { - assert(i + this->getOperation()->getNumResults() < nOutputs() && - "overflowing output buffer index"); - return this->getOperation()->getOperand(nInputs() + i); - } - /// Return the index of `value` in the list of output buffers if found, - /// llvm::None otherwise. - Optional getIndexOfOutputBuffer(Value value) { - auto it = llvm::find(getOutputBuffers(), value); - if (it != getOutputBuffers().end()) - return it - getOutputBuffers().begin(); - return llvm::None; - } - /// Return the `i`-th output buffer type. - MemRefType getOutputBufferType(unsigned i) { - return getOutputBuffer(i).getType().template cast(); - } - /// Return the `i`-th output shaped type, irrespective of buffer of tensor - /// type. - ShapedType getOutputShapedType(unsigned i) { - return getShapedType(i + nInputs()); - } - /// Query the subset of results that are of ranked tensor type. - SmallVector getOutputTensorTypes() { - SmallVector res; - for (Type type : this->getOperation()->getResults().getTypes()) - res.push_back(type.template cast()); - return res; - } - /// Return the range over outputs. - Operation::operand_range getOutputBuffers() { - auto range = this->getOperation()->getOperands(); - return {range.begin() + nInputs(), - range.begin() + getNumInputsAndOutputBuffers()}; - } - - //==========================================================================// - // Input and Output arguments handling. - //==========================================================================// - Value getBuffer(unsigned i) { - assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index"); - return this->getOperation()->getOperand(i); - } - /// Return the number of inputs and outputs, irrespective of their buffer or - /// tensor type. - unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } - /// Return the number of inputs, irrespective of their buffer or tensor type, - /// and output buffers. - unsigned getNumInputsAndOutputBuffers() { - assert(this->getOperation()->getNumResults() <= nOutputs()); - return nInputs() + nOutputs() - this->getOperation()->getNumResults(); - } - /// Return the range over inputs (irrespective of type) and output buffers. - Operation::operand_range getInputsAndOutputBuffers() { - auto range = this->getOperation()->getOperands(); - return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; - } - /// Return the `i`-th shaped type, there are 3 cases: - /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise - /// 2. if `i < getNumInputsAndOutputBuffers()` then return the - /// `getOutputBufferType(i - nInputs())`; otherwise - /// 3. return the `i - getNumInputsAndOutputBuffers()` result type. - ShapedType getShapedType(unsigned i) { - if (i < nInputs()) - return getInputShapedType(i); - if (i < getNumInputsAndOutputBuffers()) - return getOutputBufferType(i - nInputs()).template cast(); - return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()] - .template cast(); - } - /// Return the shaped types for all the inputs and outputs - SmallVector getInputOutputShapedTypes() { - SmallVector inputOutputTypes( - this->getOperation()->operand_type_begin(), - this->getOperation()->operand_type_end()); - inputOutputTypes.append(this->getOperation()->result_type_begin(), - this->getOperation()->result_type_end()); - return llvm::to_vector<4>( - llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType { - return type.cast(); - })); - } - - //==========================================================================// - // Other interface methods. - //==========================================================================// - - // Get or build the indexing_maps ArrayAttr. - ArrayAttr iterator_types() { - // Return the attribute if it is present. - if (auto attr = this->getOperation()->getAttr("iterator_types")) - return attr.template cast(); - - // If not, form the attribute using the reference iterator types for the - // ConcreteType. - auto maybeReferenceIteratorTypes = - cast(this->getOperation()).referenceIterators(); - - // If there is no reference, this must be a generic op. - // TODO: Traits are used to define ops. Split into cpp to avoid cyclic - // dependency. - auto name = this->getOperation()->getName().getStringRef(); - if (!maybeReferenceIteratorTypes && name != "generic" && - name != "indexed_generic") { - this->getOperation()->dump(); - llvm_unreachable("Op missing referenceIterators"); - } - - // If we have a reference, build the reference attribute and set it in the - // op before returning. - auto *ctx = this->getOperation()->getContext(); - auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes, - [ctx](StringRef str) -> Attribute { - return StringAttr::get(str, ctx); - }); - auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx); - // TODO: Need to memoize this. Can't just store as an attribute atm as it - // will impact parser, printer and tests. - // this->getOperation()->setAttr("iterator_types", attr); - return attr; - } - - // Get or build the indexing_maps ArrayAttr. - ArrayAttr indexing_maps() { - // Return the attribute if it is present. - if (auto attr = this->getOperation()->getAttr("indexing_maps")) - return attr.template cast(); - - // If not, form the attribute using the reference indexing map for the - // ConcreteType. - auto maybeReferenceIndexingMaps = - cast(this->getOperation()).referenceIndexingMaps(); - - // If there is no reference, this must be a generic op. - auto name = this->getOperation()->getName().getStringRef(); - if (!maybeReferenceIndexingMaps && name != "generic" && - name != "indexed_generic") { - this->getOperation()->dump(); - llvm_unreachable("Op missing referenceIndexingMaps"); - } - - // If we have a reference, build the reference attribute and set it in the - // op before returning. - auto *ctx = this->getOperation()->getContext(); - auto attrRange = - llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) { - // 0-D corner case because there is no such thing as a concrete empty - // map type. - if (!map) - map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx)); - return AffineMapAttr::get(map); - }); - SmallVector attrs{attrRange.begin(), attrRange.end()}; - auto attr = ArrayAttr::get(attrs, ctx); - // TODO: Need to memoize this. Can't just store as an attribute atm as it - // will impact parser, printer and tests. - // this->getOperation()->setAttr("indexing_maps", attr); - return attr; - } - - SmallVector getIndexingMaps() { - return llvm::to_vector<4>( - llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap { - return attr.cast().getValue(); - })); - } - - AffineMap getIndexingMap(unsigned i) { - assert(i < getNumInputsAndOutputs()); - return indexing_maps() - .getValue()[i] - .template cast() - .getValue(); - } - - AffineMap getInputIndexingMap(unsigned i) { - assert(i < nInputs()); - return indexing_maps() - .getValue()[i] - .template cast() - .getValue(); - } - - AffineMap getOutputIndexingMap(unsigned i) { - assert(i < nOutputs()); - return indexing_maps() - .getValue()[i + nInputs()] - .template cast() - .getValue(); - } - - /// Query whether the op has only buffer inputs and no returns. - bool hasBufferSemantics() { - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(getInputs(), - [](Value v) { return v.getType().isa(); }); - } - - /// Query whether the op has only tensor inputs and outputs. - bool hasTensorSemantics() { - auto isTensorType = [](Value v) { - return v.getType().isa(); - }; - return llvm::all_of(getInputs(), isTensorType) && - llvm::all_of(this->getOperation()->getResults(), isTensorType); - } - - //==========================================================================// - // Other static interface methods. - //==========================================================================// static LogicalResult verifyTrait(Operation *op) { + ConcreteType concreteOp = cast(op); auto nOperands = cast(op).getNumInputsAndOutputBuffers(); if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) return failure(); + if (op->getNumResults() > concreteOp.getNumOutputs()) + return op->emitError("unexpected #results > #outputs"); return success(); } }; -/// This class provides the API for named Linalg StructuredOps. -template -class NamedStructuredOpTraits - : public OpTrait::TraitBase { -public: - static SmallVector referenceIterators(TypeRange inputTypes, - TypeRange outputTypes); - - static SmallVector referenceIndexingMaps(TypeRange inputTypes, - TypeRange outputTypes); -}; - } // namespace linalg } // namespace OpTrait } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 77eb644..7071cd3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -260,13 +260,14 @@ static LogicalResult verifyGenericOp(GenericOpType op) { if (failed(BlockArgsVerifier::verify(op, region.front()))) return failure(); - auto attr = op.template getAttrOfType("symbol_source"); - int64_t targetRank = 0; - if (attr) { - unsigned index = attr.getInt(); + auto symbolSourceAttr = + op.template getAttrOfType("symbol_source"); + int64_t expectedNumSymbols = 0; + if (symbolSourceAttr) { + unsigned index = symbolSourceAttr.getInt(); if (index >= op.getNumOperands()) return op.emitOpError("symbol_source index out of range"); - targetRank = op.getShapedType(index).getRank(); + expectedNumSymbols = op.getShapedType(index).getRank(); } SmallVector indexingMaps; @@ -278,9 +279,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) { auto view = (idx < nInputViews) ? op.getInputShapedType(idx) : op.getOutputShapedType(idx - nInputViews); - if (m.getNumSymbols() != targetRank) + if (m.getNumSymbols() != expectedNumSymbols) return op.emitOpError("expected the number of symbols in indexing_map #") - << idx << " to match target rank"; + << idx << " to match rank of operand `symbol_source`"; if (m.getNumDims() != nLoops) return op.emitOpError("expected indexing_map #") @@ -1246,15 +1247,9 @@ void buildNamedStructuredOpRegionAndAttributes(Builder &builder, mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc()); NamedStructuredOpType::regionBuilder(*body); - auto indexingMaps = builder.getAffineMapArrayAttr( - NamedStructuredOpType::referenceIndexingMaps(operandTypes, - tensorResultTypes)); - result.addAttribute(getIndexingMapsAttrName(), indexingMaps); + // indexing_maps is an auto-generated method. - auto iterators = - builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators( - operandTypes, tensorResultTypes)); - result.addAttribute(getIteratorTypesAttrName(), iterators); + // iterator_types is an auto-generated method. } template diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index c631c47..3774aed 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -113,7 +113,7 @@ func @generic_mismatched_num_returns(%arg0: memref) { // ----- func @generic_symbol_in_map(%arg0: memref) { - // expected-error @+1 {{expected the number of symbols in indexing_map #0 to match target rank}} + // expected-error @+1 {{expected the number of symbols in indexing_map #0 to match rank of operand `symbol_source`}} linalg.generic { args_in = 0, args_out = 1, @@ -514,3 +514,20 @@ func @named_ops(%a3: memref, %b3: memref, %c3: memref, memref, memref) -> () return } + +// ----- + +func @generic(%arg0: tensor) { + // expected-error @+1 {{unexpected #results > #outputs}} + linalg.generic { + args_in = 1, + args_out = 1, + indexing_maps = [ affine_map<(i) -> (i)> ], + iterator_types = ["parallel"] + } %arg0 { + ^bb(%0: i4) : + %1 = std.addi %0, %0: i4 + linalg.yield %1, %1: i4, i4 + } : tensor -> (tensor, tensor) + return +} diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc index d796d19..aad983e 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -4,16 +4,15 @@ // ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [ // ODS-NEXT: NInputs<2> // ODS-NEXT: NOutputs<1> -// ODS-NEXT: NamedStructuredOpTraits // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: SmallVector Test1Op::referenceIterators +// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() { // IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: SmallVector Test1Op::referenceIndexingMaps +// IMPL: ArrayAttr Test1Op::indexing_maps() { // IMPL: AffineMap::get(2, 0, {d0, d1}, context), // IMPL-NEXT: AffineMap::get(2, 0, {d1}, context), -// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) }; +// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) }); // // IMPL: void Test1Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); @@ -29,16 +28,15 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { // ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [ // ODS-NEXT: NInputs<2> // ODS-NEXT: NOutputs<1> -// ODS-NEXT: NamedStructuredOpTraits // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: SmallVector Test2Op::referenceIterators +// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() { // IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: SmallVector Test2Op::referenceIndexingMaps +// IMPL: ArrayAttr Test2Op::indexing_maps() { // IMPL: AffineMap::get(3, 0, {d0, d2}, context), // IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context), -// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) }; +// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) }); // // IMPL: Test2Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); @@ -54,16 +52,15 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { // ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [ // ODS-NEXT: NInputs<2> // ODS-NEXT: NOutputs<1> -// ODS-NEXT: NamedStructuredOpTraits // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: SmallVector Test3Op::referenceIterators +// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() { // IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: SmallVector Test3Op::referenceIndexingMaps +// IMPL: ArrayAttr Test3Op::indexing_maps() { // IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context), // IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context), -// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) }; +// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) }); // // IMPL: Test3Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 92efef6..59d6556 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -974,19 +974,19 @@ public: /// Parse and print the information for a TC def. /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC. /// When `gen-impl` is used, this prints the C++ implementation for the extra - /// methods defined in ODS (referenceIterators, referenceIndexingMaps and - /// regionBuilder). + /// methods defined in ODS (`iterator_types`, `indexing_maps` and + /// `regionBuilder`). LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os); /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. void printODS(llvm::raw_ostream &os, StringRef cppOpName, StringRef linalgOpName); - /// Print the C++ StructuredOpsInterface impl of `referenceIterators`. + /// Print the C++ StructuredOpsInterface impl of `iterator_types`. void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); - /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. + /// Print the C++ StructuredOpsInterface impl of `indexing_maps`. void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); @@ -1446,7 +1446,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [ NInputs<{2}>, NOutputs<{3}>, - NamedStructuredOpTraits, SingleBlockImplicitTerminator<"YieldOp">]> { let arguments = (ins Variadic:$views); let results = (outs Variadic:$output_tensors); @@ -1465,16 +1464,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, return ::parseNamedStructuredOp<{0}>(parser, result); }]; let extraClassDeclaration = [{{ - llvm::Optional> referenceIterators(); - static SmallVector referenceIterators( - TypeRange inputTypes, TypeRange outputTypes); - - llvm::Optional> referenceIndexingMaps(); - static SmallVector referenceIndexingMaps( - TypeRange inputTypes, TypeRange outputTypes); - + ArrayAttr iterator_types(); + ArrayAttr indexing_maps(); static void regionBuilder(Block &block); - std::string getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } @@ -1492,20 +1484,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs); } -/// Print the C++ StructuredOpsInterface impl of `referenceIterators`. +/// Print the C++ StructuredOpsInterface impl of `iterator_types`. void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state) { const char *referenceReferenceIteratorsFmt = R"FMT( - // This is temporary until we transition out of manually specified ops - // that should be auto-generated with linalg-ods-gen. - llvm::Optional> {0}::referenceIterators() {{ - llvm_unreachable("Unexpected missing `iterator_types` attribute."); - } - SmallVector {0}::referenceIterators( - TypeRange inputTypes, TypeRange outputTypes) { - return SmallVector{{ {1} }; + ArrayAttr {0}::iterator_types() { + return Builder(getContext()).getStrArrayAttr(SmallVector{{ {1} }); })FMT"; std::string iteratorsStr; @@ -1542,16 +1528,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, R"FMT( // This is temporary until we transition out of manually specified ops that // should be auto-generated with linalg-ods-gen. - llvm::Optional> {0}::referenceIndexingMaps() {{ - llvm_unreachable("Unexpected missing `indexing_maps` attribute."); - } - SmallVector {0}::referenceIndexingMaps( - TypeRange inputTypes, TypeRange outputTypes) { - assert(!inputTypes.empty() && "At least one input expected"); - MLIRContext *context = (*inputTypes.begin()).getContext(); + ArrayAttr {0}::indexing_maps() { + MLIRContext *context = getContext(); AffineExpr {1}; bindDims(context, {1}); - return SmallVector{{ {2} }; + return Builder(context).getAffineMapArrayAttr({ {2} }); })FMT"; // 2. Print a comma-separated list of identifiers for the AffineExpr in -- 2.7.4