#ifndef LINALG_BASE
#define LINALG_BASE
+include "mlir/Dialect/Utils/StructuredOpsUtils.td"
include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
let assemblyFormat = "`<` $value `>`";
}
+def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
+ "Iterator type should be an enum.">;
+
#endif // LINALG_BASE
namespace mlir {
namespace linalg {
+class IteratorTypeAttr;
class LinalgOp;
namespace detail {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return getNumIterators(getParallelIteratorTypeName(),
- $_op.getIteratorTypesArray());
+ return llvm::count($_op.getIteratorTypesArray(),
+ utils::IteratorType::parallel);
}]
>,
InterfaceMethod<
/*methodBody=*/"",
/*defaultImplementation=*/[{
return findPositionsOfType($_op.getIteratorTypesArray(),
- getParallelIteratorTypeName(), res);
+ utils::IteratorType::parallel, res);
}]
>,
InterfaceMethod<
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return getNumIterators(getReductionIteratorTypeName(),
- $_op.getIteratorTypesArray());
+ return llvm::count($_op.getIteratorTypesArray(),
+ utils::IteratorType::reduction);
}]
>,
InterfaceMethod<
/*methodBody=*/"",
/*defaultImplementation=*/[{
return findPositionsOfType($_op.getIteratorTypesArray(),
- getReductionIteratorTypeName(), res);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the number of window loops.
- }],
- /*retTy=*/"unsigned",
- /*methodName=*/"getNumWindowLoops",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return getNumIterators(getWindowIteratorTypeName(),
- $_op.getIteratorTypesArray());
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the dims that are window loops.
- }],
- /*retTy=*/"void",
- /*methodName=*/"getWindowDims",
- /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return findPositionsOfType($_op.getIteratorTypesArray(),
- getWindowIteratorTypeName(), res);
+ utils::IteratorType::reduction, res);
}]
>,
InterfaceMethod<
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return getNumIterators($_op.getIteratorTypesArray());
+ return $_op.getIteratorTypesArray().size();
}]
>,
InterfaceMethod<
/*defaultImplementation=*/[{
auto iters = $_op.getIteratorTypesArray();
return iters.size() == 1 &&
- getNumIterators(getReductionIteratorTypeName(), iters) == 1;
+ llvm::count(iters, utils::IteratorType::reduction) == 1;
}]>,
//===------------------------------------------------------------------===//
// Input and Init arguments handling.
can be infered from other parameters and in such cases default
getIteratorTypesArray should be overriden.
}],
- /*retTy=*/"SmallVector<StringRef>",
+ /*retTy=*/"SmallVector<utils::IteratorType>",
/*methodName=*/"getIteratorTypesArray",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
+ auto range = $_op.getIteratorTypes()
+ .template getAsValueRange<IteratorTypeAttr,
+ utils::IteratorType>();
return {range.begin(), range.end()};
}]
>,
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
- SmallVector<StringRef> getIteratorTypeNames() {
- return getIteratorTypesArray();
- }
-
//========================================================================//
// Forwarding functions to access interface methods from the
// DestinationStyleOpInterface.
let arguments = (ins Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps,
- ArrayAttr:$iterator_types,
+ IteratorTypeArrayAttr:$iterator_types,
OptionalAttr<StrAttr>:$doc,
OptionalAttr<StrAttr>:$library_call);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
- "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
+ "ArrayRef<utils::IteratorType>":$iteratorTypes, "StringRef":$doc,
"StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
+ "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
"StringRef":$doc, "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
- "ArrayRef<StringRef>":$iteratorTypes,
+ "ArrayRef<utils::IteratorType>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
+ "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Implement functions necessary for LinalgStructuredInterface.
- SmallVector<StringRef> getIteratorTypesArray();
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
- SmallVector<StringRef> getIteratorTypesArray();
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
- SmallVector<StringRef> getIteratorTypesArray();
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
- SmallVector<StringRef> getIteratorTypesArray();
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
bool isElementwise(LinalgOp op);
/// Check if iterator type has "parallel" semantics.
-bool isParallelIterator(StringRef iteratorType);
+bool isParallelIterator(utils::IteratorType iteratorType);
/// Check if iterator type has "reduction" semantics.
-bool isReductionIterator(StringRef iteratorType);
+bool isReductionIterator(utils::IteratorType iteratorType);
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
template <typename LoopTy>
struct GenerateLoopNest {
static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
- LinalgOp linalgOp, ArrayRef<StringRef> iteratorTypes,
+ LinalgOp linalgOp,
+ ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location,
ValueRange, ValueRange)>
bodyBuilderFn,
namespace tosa {
// Creates a SmallVector of Stringrefs for N parallel loops
-SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
+SmallVector<utils::IteratorType>
+getNParallelLoopsAttrs(unsigned nParallelLoops);
// Takes a vector of values and condenses them to a vector with no gaps.
SmallVector<Value> condenseValues(const SmallVector<Value> &values);
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/StringRef.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
/// the reduction.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
-/// Use to encode that a particular iterator type has parallel semantics.
-constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
-
-/// Use to encode that a particular iterator type has reduction semantics.
-constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
-
-/// Use to encode that a particular iterator type has window semantics.
-constexpr StringRef getWindowIteratorTypeName() { return "window"; }
-
-/// Use to encode that a particular iterator type has window semantics.
-inline ArrayRef<StringRef> getAllIteratorTypeNames() {
- static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
- getReductionIteratorTypeName(),
- getWindowIteratorTypeName()};
- return llvm::makeArrayRef(names);
-}
-
-/// Returns the iterator of a certain type.
-inline unsigned getNumIterators(StringRef name,
- ArrayRef<StringRef> iteratorTypes) {
- auto names = getAllIteratorTypeNames();
- (void)names;
- assert(llvm::is_contained(names, name));
- return llvm::count(iteratorTypes, name);
-}
-
-inline unsigned getNumIterators(ArrayRef<StringRef> iteratorTypes) {
- unsigned res = 0;
- for (auto n : getAllIteratorTypeNames())
- res += getNumIterators(n, iteratorTypes);
- return res;
-}
-
/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
-inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
- StringRef iteratorTypeName,
+inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
+ utils::IteratorType iteratorTypeName,
SmallVectorImpl<unsigned> &res) {
for (const auto &en : llvm::enumerate(iteratorTypes)) {
if (en.value() == iteratorTypeName)
/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
/// yet implement the StructuredOpInterface itself.
-template <typename StructuredOpInterface>
+template <typename StructuredOpInterface, typename IteratorTypeT>
class StructuredGenerator {
public:
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
struct IteratorType {
- IteratorType(StringRef strRef) : strRef(strRef) {}
- bool isOfType(StringRef typeName) const { return typeName == strRef; }
- StringRef strRef;
+ IteratorType(IteratorTypeT iter) : iter(iter) {}
+ bool isOfType(IteratorTypeT expectedIter) const {
+ return expectedIter == iter;
+ }
+ IteratorTypeT iter;
};
struct Par : public IteratorType {
- Par() : IteratorType(getParallelIteratorTypeName()) {}
+ Par() : IteratorType(IteratorTypeT::parallel) {}
};
struct Red : public IteratorType {
- Red() : IteratorType(getReductionIteratorTypeName()) {}
- };
- struct Win : public IteratorType {
- Win() : IteratorType(getWindowIteratorTypeName()) {}
+ Red() : IteratorType(IteratorTypeT::reduction) {}
};
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
- iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
+ iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
op(op) {}
bool iters(ArrayRef<IteratorType> its) {
OpBuilder &builder;
MLIRContext *ctx;
Location loc;
- SmallVector<StringRef> iterators;
+ SmallVector<IteratorTypeT> iterators;
SmallVector<AffineMap, 4> maps;
Operation *op;
};
return CombiningKind::ADD;
}
- // Returns iterator types in string format.
- SmallVector<StringRef> getIteratorTypeNames() {
- return llvm::to_vector(
- llvm::map_range(getIteratorTypes(), [](Attribute a) {
- return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
- }));
+ SmallVector<IteratorType> getIteratorTypesArray() {
+ auto range =
+ getIteratorTypes()
+ .template getAsValueRange<IteratorTypeAttr, IteratorType>();
+ return {range.begin(), range.end()};
}
}];
SmallVector<AffineExpr, 2> srcExprs;
SmallVector<AffineExpr, 2> dstExprs;
- SmallVector<StringRef, 4> iteratorTypes;
+ SmallVector<utils::IteratorType, 4> iteratorTypes;
for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
- iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
- : getParallelIteratorTypeName());
+ iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction
+ : utils::IteratorType::parallel);
if (axis != i)
dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
}
auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
inputExprs, builder.getContext());
auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
- SmallVector<StringRef> iterators(4, getParallelIteratorTypeName());
+ SmallVector<utils::IteratorType> iterators(4,
+ utils::IteratorType::parallel);
Value empty = builder.create<tensor::EmptyOp>(
resultTy.getShape(), resultTy.getElementType(), outputDynSize);
// We need to reduce along the arg-max axis, with parallel operations along
// the rest.
- SmallVector<StringRef, 4> iteratorTypes;
- iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
- iteratorTypes[axis] = getReductionIteratorTypeName();
+ SmallVector<utils::IteratorType, 4> iteratorTypes;
+ iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
+ iteratorTypes[axis] = utils::IteratorType::reduction;
SmallVector<AffineExpr, 2> srcExprs;
SmallVector<AffineExpr, 2> dstExprs;
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Batch dimension.
- if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
if (inputExprWalker.convolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Output image Loop dimension.
- if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
!inputExprWalker.unConvolvedDims.count(outputDim) &&
filterDims.count(outputDim)) {
// Output channel dimension.
- if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
filterDims.count(outputDim)) {
// Depth multiplier.
- if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
if (inputExprWalker.convolvedDims.count(filterDim) &&
!outputDims.count(filterDim)) {
// Filter loop dimension.
- if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
+ if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
!outputDims.count(filterDim)) {
// Input channel dimension.
- if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
+ if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
- // Check all iterator types are known.
- auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
- for (StringRef iteratorType : iteratorTypesRange) {
- if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) ||
- !utils::symbolizeIteratorType(iteratorType).has_value())
- return op->emitOpError("unexpected iterator_type (")
- << iteratorType << ")";
- }
-
// Before checking indexing maps, we need to make sure the attributes
// referenced by it are valid.
if (linalgOp.hasDynamicIndexingMaps())
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+ ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
+ StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, resultTensorTypes, inputs, outputs,
builder.getAffineMapArrayAttr(indexingMaps),
- builder.getStrArrayAttr(iteratorTypes),
+ builder.getArrayAttr(llvm::to_vector(llvm::map_range(
+ iteratorTypes,
+ [&](utils::IteratorType iter) -> mlir::Attribute {
+ return IteratorTypeAttr::get(builder.getContext(), iter);
+ }))),
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
bodyBuild, attributes);
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+ ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
+ StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
llvm::StringSet<> genericAttrNamesSet;
genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
SmallVector<NamedAttribute, 8> genericAttrs;
- for (auto attr : (*this)->getAttrs())
- if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
+ for (auto attr : (*this)->getAttrs()) {
+ if (attr.getName() == getIteratorTypesAttrName()) {
+ auto iteratorTypes =
+ attr.getValue()
+ .cast<ArrayAttr>()
+ .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
+ // Convert IteratorType enums into the string representation. This is
+ // needed, because tests still use the old format when 'iterator_types'
+ // attribute is represented as an array of strings.
+ // TODO: Remove this conversion once tests are fixed.
+ SmallVector<Attribute> iteratorTypeNames =
+ llvm::to_vector(llvm::map_range(
+ iteratorTypes, [&](utils::IteratorType t) -> Attribute {
+ return StringAttr::get(getContext(), stringifyIteratorType(t));
+ }));
+
+ genericAttrs.emplace_back(
+ getIteratorTypesAttrName(),
+ ArrayAttr::get(getContext(), iteratorTypeNames));
+ } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
genericAttrs.push_back(attr);
+ }
+ }
if (!genericAttrs.empty()) {
auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
p << genericDictAttr;
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
+ // Convert array of string into an array of IteratyType enums. This is needed,
+ // because tests still use the old format when 'iterator_types' attribute is
+ // represented as an array of strings.
+ // TODO: Remove this conversion once tests are fixed.
+ ArrayAttr iteratorTypes =
+ result.attributes.get(getIteratorTypesAttrName(result.name))
+ .cast<ArrayAttr>();
+
+ SmallVector<Attribute> iteratorTypeAttrs;
+
+ for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
+ auto maybeIteratorType = utils::symbolizeIteratorType(s);
+ if (!maybeIteratorType.has_value())
+ return parser.emitError(parser.getCurrentLocation())
+ << "unexpected iterator_type (" << s << ")";
+
+ iteratorTypeAttrs.push_back(
+ IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
+ }
+ result.attributes.set(getIteratorTypesAttrName(result.name),
+ parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
+
// Parsing is shared with named ops, except for the region.
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return success();
}
-SmallVector<StringRef> MapOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
int64_t rank = getInit().getType().getRank();
- return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+ return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
ArrayAttr MapOp::getIndexingMaps() {
inputs, inits, bodyBuild);
}
-SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
- SmallVector<StringRef> iteratorTypes(inputRank,
- getParallelIteratorTypeName());
+ SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+ utils::IteratorType::parallel);
for (int64_t reductionDim : getDimensions())
- iteratorTypes[reductionDim] = getReductionIteratorTypeName();
+ iteratorTypes[reductionDim] = utils::IteratorType::reduction;
return iteratorTypes;
}
return success();
}
-SmallVector<StringRef> TransposeOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
int64_t rank = getInit().getType().getRank();
- return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+ return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
ArrayAttr TransposeOp::getIndexingMaps() {
return success();
}
-SmallVector<StringRef> BroadcastOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
int64_t rank = getInit().getType().getRank();
- return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+ return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
ArrayAttr BroadcastOp::getIndexingMaps() {
.getValue()
.isProjectedPermutation();
}) &&
- genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 &&
- llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
- return it == getParallelIteratorTypeName();
- });
+ genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
+ 0 &&
+ llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
}
namespace {
}
// The iterator types of the expanded op are all parallel.
- SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
- getParallelIteratorTypeName());
+ SmallVector<utils::IteratorType> iteratorTypes(
+ expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
TypeRange resultTypes = ValueRange(outputs).getTypes();
auto fusedOp =
continue;
// Check that all folded iterator types are all parallel or all reductions.
- StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
+ utils::IteratorType startIteratorType =
+ iteratorTypes[foldedIterationSpaceDims[0]];
if (!isParallelIterator(startIteratorType) &&
!isReductionIterator(startIteratorType))
continue;
/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
-static SmallVector<StringRef>
-getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
+static SmallVector<utils::IteratorType>
+getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
const CollapsingInfo &collapsingInfo) {
- SmallVector<StringRef> collapsedIteratorTypes;
+ SmallVector<utils::IteratorType> collapsedIteratorTypes;
for (ReassociationIndicesRef foldedIterDims :
collapsingInfo.getCollapsedOpToOrigOpMapping()) {
assert(!foldedIterDims.empty() &&
// Just pick the iterator type of the first folded dim. Pre-condition checks
// expected to have checked that iterator types of all folded dimensions are
// the same.
- collapsedIteratorTypes.push_back(
- iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
+ collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
}
return collapsedIteratorTypes;
}
}
// Get the iterator types for the operand.
- SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
- genericOp.getIteratorTypes().getValue(), collapsingInfo);
+ SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
+ genericOp.getIteratorTypesArray(), collapsingInfo);
// Get the indexing maps.
auto indexingMaps = llvm::to_vector(
SmallVector<AffineMap, 3> indexingMaps(
op->getNumResults() + op->getNumOperands(),
rewriter.getMultiDimIdentityMap(rank));
- SmallVector<StringRef, 6> iteratorTypes(rank,
- getParallelIteratorTypeName());
+ SmallVector<utils::IteratorType, 6> iteratorTypes(
+ rank, utils::IteratorType::parallel);
auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/op->getResultTypes(),
SmallVector<Value> inputs = linalgOp.getDpsInputOperands();
SmallVector<Value> outputs = linalgOp.getDpsInitOperands();
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
- SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
+ SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
? TypeRange(ValueRange(outputs))
: TypeRange{};
newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
op.getContext()));
- SmallVector<StringRef> newIteratorTypes;
+ SmallVector<utils::IteratorType> newIteratorTypes;
for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
if (insertSplitDimension == it.index() && !control.innerParallel)
- newIteratorTypes.push_back(getParallelIteratorTypeName());
+ newIteratorTypes.push_back(utils::IteratorType::parallel);
newIteratorTypes.push_back(it.value());
if (insertSplitDimension == it.index() && control.innerParallel)
- newIteratorTypes.push_back(getParallelIteratorTypeName());
+ newIteratorTypes.push_back(utils::IteratorType::parallel);
}
// Create the new op matching the original op with an extra parallel
// dimension.
// from the previous op.
unsigned intermRank = newOutputShape.size();
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
- SmallVector<StringRef> reductionIteratorTypes;
+ SmallVector<utils::IteratorType> reductionIteratorTypes;
SmallVector<AffineExpr> exprs;
for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
if (insertSplitDimension == i) {
- reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+ reductionIteratorTypes.push_back(utils::IteratorType::reduction);
} else {
exprs.push_back(b.getAffineDimExpr(i));
- reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+ reductionIteratorTypes.push_back(utils::IteratorType::parallel);
}
}
AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
// dimension.
auto iteratorTypes = op.getIteratorTypesArray();
iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
- getParallelIteratorTypeName());
+ utils::IteratorType::parallel);
GenericOp genericOp =
b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
newOutputs, newMaps, iteratorTypes);
AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
SmallVector<AffineMap> indexingMaps = {
map, map.dropResult(insertSplitDimension)};
- SmallVector<StringRef> reductionIteratorTypes(
- originalOutputType.getRank() + 1, getParallelIteratorTypeName());
+ SmallVector<utils::IteratorType> reductionIteratorTypes(
+ originalOutputType.getRank() + 1, utils::IteratorType::parallel);
reductionIteratorTypes[insertSplitDimension] =
- getReductionIteratorTypeName();
+ utils::IteratorType::reduction;
// clang-format off
auto reductionOp = b.create<GenericOp>(
auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
- SmallVector<StringRef, 4> iteratorTypes;
+ SmallVector<utils::IteratorType, 4> iteratorTypes;
for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
if (loopIndexToRangeIndex.count(attr.index()))
iteratorTypes.push_back(attr.value());
/// Return the loop iterator type.
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
- return llvm::to_vector(llvm::map_range(
- concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) {
- return utils::symbolizeIteratorType(iteratorType).value();
- }));
+ return concreteOp.getIteratorTypesArray();
}
/// Return the iteration domain range.
// Step3. create a generic op where the reduction dimension is replaced by a
// parallel dimension of the size of reduction.
- SmallVector<StringRef> newIteratorTypes = linalgOp.getIteratorTypesArray();
- newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName();
+ SmallVector<utils::IteratorType> newIteratorTypes =
+ linalgOp.getIteratorTypesArray();
+ newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel;
SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
linalgOp.getContext());
int64_t intermRank =
partialReduce[0].getType().cast<ShapedType>().getRank();
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
- SmallVector<StringRef> reductionIteratorTypes;
+ SmallVector<utils::IteratorType> reductionIteratorTypes;
SmallVector<AffineExpr> exprs;
for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
if (dimToMerge == i) {
- reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+ reductionIteratorTypes.push_back(utils::IteratorType::reduction);
} else {
exprs.push_back(b.getAffineDimExpr(i));
- reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+ reductionIteratorTypes.push_back(utils::IteratorType::parallel);
}
}
AffineMap outputMap =
return vectorizeCopy(rewriter, copyOp);
}
-static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
- return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+static SmallVector<utils::IteratorType>
+getNParallelLoopsAttrs(unsigned nParallelLoops) {
+ return SmallVector<utils::IteratorType>(nParallelLoops,
+ utils::IteratorType::parallel);
}
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
/// ```
/// kw is unrolled, w is unrolled iff dilationW > 1.
-struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
+struct Conv1DGenerator
+ : public StructuredGenerator<LinalgOp, utils::IteratorType> {
Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
int dilationW)
- : StructuredGenerator<LinalgOp>(builder, linalgOp), strideW(strideW),
- dilationW(dilationW) {
+ : StructuredGenerator<LinalgOp, utils::IteratorType>(builder, linalgOp),
+ strideW(strideW), dilationW(dilationW) {
// Determine whether `linalgOp` can be generated with this generator
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
return;
return hasOnlyScalarElementwiseOp(op->getRegion(0));
}
-bool isParallelIterator(StringRef iteratorType) {
- return iteratorType == getParallelIteratorTypeName();
+bool isParallelIterator(utils::IteratorType iteratorType) {
+ return iteratorType == utils::IteratorType::parallel;
}
-bool isReductionIterator(StringRef iteratorType) {
- return iteratorType == getReductionIteratorTypeName();
+bool isReductionIterator(utils::IteratorType iteratorType) {
+ return iteratorType == utils::IteratorType::reduction;
}
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
b.getContext())),
AffineMap::getMultiDimIdentityMap(transposeVector.size(),
b.getContext())};
- SmallVector<llvm::StringRef> iteratorTypes(transposeVector.size(),
- getParallelIteratorTypeName());
+ SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
+ utils::IteratorType::parallel);
// Create a GenericOp to transpose `inputTensor` into `outputTensor`.
- auto transposeOp = b.create<GenericOp>(
- loc, resultTensorType, inputTensor, outputTensor,
- b.getAffineMapArrayAttr(indexingMaps), b.getStrArrayAttr(iteratorTypes),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
+ auto transposeOp =
+ b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
+ indexingMaps, iteratorTypes);
Region &body = transposeOp.getRegion();
body.push_back(new Block());
body.front().addArguments({elementType, elementType}, {loc, loc});
AffineMap id =
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
- SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
- getParallelIteratorTypeName());
+ SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
+ utils::IteratorType::parallel);
return b.create<linalg::GenericOp>(
loc,
/*inputs=*/from,
template <>
void GenerateLoopNest<scf::ForOp>::doit(
OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
- ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
ValueRange)>
bodyBuilderFn,
template <>
void GenerateLoopNest<AffineForOp>::doit(
OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
- ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
ValueRange)>
bodyBuilderFn,
// exceeds 10.
static void generateParallelLoopNest(
OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
- ValueRange steps, ArrayRef<StringRef> iteratorTypes,
+ ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
ArrayRef<linalg::ProcInfo> procInfo,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
SmallVectorImpl<Value> &ivStorage) {
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
- ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
ValueRange)>
bodyBuilderFn,
/// as we use adj matrix for the graph.
/// The sorted result will put the first Reduction iterator to the
/// latest possible index.
-static bool topSortOptimal(unsigned n, ArrayRef<StringRef> iteratorTypes,
+static bool topSortOptimal(unsigned n,
+ ArrayRef<utils::IteratorType> iteratorTypes,
std::vector<unsigned> &topSort,
std::vector<unsigned> &inDegree,
std::vector<std::vector<bool>> &adjM) {
using namespace mlir;
using namespace mlir::tosa;
-SmallVector<StringRef>
+SmallVector<utils::IteratorType>
mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
- return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+ return SmallVector<utils::IteratorType>(nParallelLoops,
+ utils::IteratorType::parallel);
}
SmallVector<Value>
}
namespace {
-struct IteratorType {
- IteratorType(StringRef strRef) : strRef(strRef) {}
- bool isOfType(Attribute attr) const {
- auto sAttr = attr.dyn_cast<StringAttr>();
- return sAttr && sAttr.getValue() == strRef;
- }
- StringRef strRef;
-};
-struct Par : public IteratorType {
- Par() : IteratorType(getParallelIteratorTypeName()) {}
-};
-struct Red : public IteratorType {
- Red() : IteratorType(getReductionIteratorTypeName()) {}
-};
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
- : public StructuredGenerator<vector::ContractionOp> {
+ : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
- : StructuredGenerator<vector::ContractionOp>(builder, op),
+ : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
+ builder, op),
kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
res(op.getAcc()), lhsType(op.getLhsType()) {}
} else {
MemRefLayoutAttrInterface updatedLayout;
if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
- auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
- updatedLayout = StridedLayoutAttr::get(strided.getContext(), strided.getOffset(), strides);
+ auto strides =
+ llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+ updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+ strided.getOffset(), strides);
} else {
AffineMap map = srcType.getLayout().getAffineMap();
int numSymbols = map.getNumSymbols();
# TODO: Support emission of pure memref form.
indexing_maps_attr = ArrayAttr.get(
[AffineMapAttr.get(am) for am in indexing_maps])
- iterator_types_attr = ArrayAttr.get(
- [StringAttr.get(s) for s in op_config.iterator_types])
+ iterator_types_attr = ArrayAttr.get([
+ Attribute.parse(f"#linalg.iterator_type<{s}>")
+ for s in op_config.iterator_types
+ ])
# Compute the index attributes used when emitting a named structured op.
index_attrs = {} # type: Dict[str, DenseElementAttr]
# An operation is rank polymorphic if the iteration domain has rank zero.
if not iterator_types_attr:
rank = ShapedType(outs[0].type).rank
- iterator_types_attr = ArrayAttr.get([StringAttr.get("parallel")] * rank)
+ iterator_types_attr = ArrayAttr.get(
+ [Attribute.parse("#linalg.iterator_type<parallel>")] * rank)
scalar_map = AffineMap.get(rank, 0, [])
tensor_map = AffineMap.get_identity(rank)
indexing_maps = []
// expected-error @+1 {{expected op with 2 inputs and 1 output}}
%0 = test.linalg_conv_op {
indexing_maps = [#map, #map],
- iterator_types = ["parallel"]}
+ iterator_types = [#test.iterator_type<parallel>]}
ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
^bb0(%arg2 : f32, %arg3 : f32):
linalg.yield %arg3 : f32
indexing_maps = [affine_map<(d0, d1) -> (d0 * 2)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<reduction>]}
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d0)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<reduction>]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d1 + d0)>,
affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<reduction>]}
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0 + d1)>],
- iterator_types = ["parallel", "parallel"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<parallel>]}
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<parallel>]}
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
affine_map<(d0, d1, d2) -> (d1)>,
affine_map<(d0, d1, d2) -> (d0, d2)>],
- iterator_types = ["parallel", "reduction", "parallel"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<reduction>,
+ #test.iterator_type<parallel>]}
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0)>],
- iterator_types = ["parallel", "reduction", "reduction"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<reduction>,
+ #test.iterator_type<reduction>]}
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1, d2)>,
affine_map<(d0, d1, d2) -> (d1)>,
affine_map<(d0, d1, d2) -> (d0)>],
- iterator_types = ["parallel", "reduction", "reduction"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<reduction>,
+ #test.iterator_type<reduction>]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "parallel"]}
+ iterator_types = [#test.iterator_type<parallel>,
+ #test.iterator_type<parallel>]}
ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
outs(%arg2 : tensor<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
// -----
func.func @generic_wrong_iterator(%arg0: memref<1xi32>) {
- // expected-error @+1 {{op unexpected iterator_type (random)}}
+ // expected-error @+4 {{unexpected iterator_type (random)}}
linalg.generic {
indexing_maps = [ affine_map<(i) -> (i)> ],
iterator_types = ["random"]}
^bb1(%arg1: !pdl.operation):
%match_attr = transform.structured.match
ops{["linalg.generic"]}
- attributes{iterator_types = ["parallel", "parallel", "parallel"]}
+ attributes{iterator_types = [
+ #linalg.iterator_type<parallel>,
+ #linalg.iterator_type<parallel>,
+ #linalg.iterator_type<parallel>]}
in %arg1
transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
transform.test_consume_operand %match_attr
%no_match = transform.structured.match
- attributes{iterator_types = ["parallel", "parallel", "reduction"]}
+ attributes{iterator_types = [
+ #linalg.iterator_type<parallel>,
+ #linalg.iterator_type<parallel>,
+ #linalg.iterator_type<reduction>]}
in %arg1
// expected-remark @below {{0}}
transform.test_print_number_of_associated_payload_ir_ops %no_match
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test)
add_public_tablegen_target(MLIRTestTypeDefIncGen)
+set(LLVM_TARGET_DEFINITIONS TestEnumDefs.td)
+mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRTestEnumDefIncGen)
+
set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestOps.h.inc -gen-op-decls)
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test)
mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
-mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
-mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(TestPatterns.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestOpsIncGen)
DEPENDS
MLIRTestAttrDefIncGen
+ MLIRTestEnumDefIncGen
MLIRTestInterfaceIncGen
MLIRTestTypeDefIncGen
MLIRTestOpsIncGen
// To get the test dialect definition.
include "TestDialect.td"
+include "TestEnumDefs.td"
+include "mlir/Dialect/Utils/StructuredOpsUtils.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
"array_of_ints", "int32_t">;
// An array of enum attributes.
-def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
- I32EnumAttrCase<"a", 0>,
- I32EnumAttrCase<"b", 1>
- ]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::test";
-}
def TestSimpleEnumAttr : EnumAttr<Test_Dialect, TestSimpleEnum, "simple_enum"> {
let assemblyFormat = "`` $value";
}
let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom<TrueFalse>($b)^ `>`)?";
}
+def Test_IteratorTypeEnum
+ : EnumAttr<Test_Dialect, IteratorType, "iterator_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def Test_IteratorTypeArrayAttr
+ : TypedArrayAttrBase<Test_IteratorTypeEnum,
+ "Iterator type should be an enum.">;
+
+
#endif // TEST_ATTRDEFS
#include <tuple>
#include "TestTraits.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
--- /dev/null
+//===-- TestEnumDefs.td - Test dialect enum definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen enum definitions for Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_ENUMDEFS
+#define TEST_ENUMDEFS
+
+include "mlir/IR/EnumAttr.td"
+
+def I32Case5: I32EnumAttrCase<"case5", 5>;
+def I32Case10: I32EnumAttrCase<"case10", 10>;
+
+def SomeI32Enum: I32EnumAttr<
+ "SomeI32Enum", "", [I32Case5, I32Case10]>;
+
+def I64Case5: I64EnumAttrCase<"case5", 5>;
+def I64Case10: I64EnumAttrCase<"case10", 10>;
+
+def SomeI64Enum: I64EnumAttr<
+ "SomeI64Enum", "", [I64Case5, I64Case10]>;
+
+//===----------------------------------------------------------------------===//
+// Test Enum
+//===----------------------------------------------------------------------===//
+
+// Define the C++ enum.
+def TestEnum
+ : I32EnumAttr<"TestEnum", "a test enum", [
+ I32EnumAttrCase<"First", 0, "first">,
+ I32EnumAttrCase<"Second", 1, "second">,
+ I32EnumAttrCase<"Third", 2, "third">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "test";
+}
+
+def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
+ I32EnumAttrCase<"a", 0>,
+ I32EnumAttrCase<"b", 1>
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::test";
+}
+
+//===----------------------------------------------------------------------===//
+// Test Bit Enum
+//===----------------------------------------------------------------------===//
+
+// Define the C++ enum.
+def TestBitEnum
+ : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
+ I32BitEnumAttrCaseBit<"Read", 0, "read">,
+ I32BitEnumAttrCaseBit<"Write", 1, "write">,
+ I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "test";
+ let separator = ", ";
+}
+
+// Define an enum with a different separator
+def TestBitEnumVerticalBar
+ : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
+ I32BitEnumAttrCaseBit<"User", 0, "user">,
+ I32BitEnumAttrCaseBit<"Group", 1, "group">,
+ I32BitEnumAttrCaseBit<"Other", 2, "other">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "test";
+ let separator = " | ";
+}
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Multi-result Ops)
+
+def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
+def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
+def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
+def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
+def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
+def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
+
+def MultiResultOpEnum: I64EnumAttr<
+ "MultiResultOpEnum", "Multi-result op kinds", [
+ MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
+ MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
+ ]>;
+
+#endif // TEST_ENUMDEFS
);
}
-def I32Case5: I32EnumAttrCase<"case5", 5>;
-def I32Case10: I32EnumAttrCase<"case10", 10>;
-
-def SomeI32Enum: I32EnumAttr<
- "SomeI32Enum", "", [I32Case5, I32Case10]>;
-
def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
let arguments = (ins SomeI32Enum:$attr);
let results = (outs I32:$val);
}
-def I64Case5: I64EnumAttrCase<"case5", 5>;
-def I64Case10: I64EnumAttrCase<"case10", 10>;
-
-def SomeI64Enum: I64EnumAttr<
- "SomeI64Enum", "", [I64Case5, I64Case10]>;
-
def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
let arguments = (ins SomeI64Enum:$attr);
let results = (outs I32:$val);
// Test Enum Attributes
//===----------------------------------------------------------------------===//
-// Define the C++ enum.
-def TestEnum
- : I32EnumAttr<"TestEnum", "a test enum", [
- I32EnumAttrCase<"First", 0, "first">,
- I32EnumAttrCase<"Second", 1, "second">,
- I32EnumAttrCase<"Third", 2, "third">,
- ]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "test";
-}
-
// Define the enum attribute.
def TestEnumAttr : EnumAttr<Test_Dialect, TestEnum, "enum">;
// Test Bit Enum Attributes
//===----------------------------------------------------------------------===//
-// Define the C++ enum.
-def TestBitEnum
- : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
- I32BitEnumAttrCaseBit<"Read", 0, "read">,
- I32BitEnumAttrCaseBit<"Write", 1, "write">,
- I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
- ]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "test";
- let separator = ", ";
-}
-
// Define the enum attribute.
def TestBitEnumAttr : EnumAttr<Test_Dialect, TestBitEnum, "bit_enum"> {
let assemblyFormat = "`<` $value `>`";
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
}
-// Define an enum with a different separator
-def TestBitEnumVerticalBar
- : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
- I32BitEnumAttrCaseBit<"User", 0, "user">,
- I32BitEnumAttrCaseBit<"Group", 1, "group">,
- I32BitEnumAttrCaseBit<"Other", 2, "other">,
- ]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "test";
- let separator = " | ";
-}
-
def TestBitEnumVerticalBarAttr
: EnumAttr<Test_Dialect, TestBitEnumVerticalBar, "bit_enum_vbar"> {
let assemblyFormat = "`<` $value `>`";
def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
-//===----------------------------------------------------------------------===//
-// Test Patterns (Multi-result Ops)
-
-def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
-def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
-def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
-def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
-def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
-def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
-
-def MultiResultOpEnum: I64EnumAttr<
- "MultiResultOpEnum", "Multi-result op kinds", [
- MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
- MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
- ]>;
-
def ThreeResultOp : TEST_Op<"three_result"> {
let arguments = (ins MultiResultOpEnum:$kind);
let results = (outs I32:$result1, F32:$result2, F32:$result3);
return ®ionBuilder;
}
- mlir::ArrayAttr getIteratorTypes() {
- return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+ llvm::SmallVector<mlir::utils::IteratorType> getIteratorTypesArray() {
+ auto attrs = getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+ auto range = attrs.getAsValueRange<IteratorTypeAttr, mlir::utils::IteratorType>();
+ return {range.begin(), range.end()};
}
mlir::ArrayAttr getIndexingMaps() {
return ®ionBuilder;
}
- mlir::ArrayAttr getIteratorTypes() {
- return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+ llvm::SmallVector<mlir::utils::IteratorType> getIteratorTypesArray() {
+ auto attrs = getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+ auto range = attrs.getAsValueRange<IteratorTypeAttr, mlir::utils::IteratorType>();
+ return {range.begin(), range.end()};
}
mlir::ArrayAttr getIndexingMaps() {
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
- SmallVector<StringRef> getIteratorTypesArray();
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
// {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
R"FMT(
-SmallVector<StringRef> {0}::getIteratorTypesArray() {{
- return SmallVector<StringRef>{{ {1} };
+SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
+ return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
R"FMT(
-SmallVector<StringRef> {0}::getIteratorTypesArray() {{
+SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
int64_t rank = getRank(getDpsInitOperand(0));
- return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+ return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";
[&](LinalgIteratorTypeDef it) {
switch (it) {
case LinalgIteratorTypeDef::parallel:
- ss << "getParallelIteratorTypeName()";
+ ss << "utils::IteratorType::parallel";
break;
case LinalgIteratorTypeDef::reduction:
- ss << "getReductionIteratorTypeName()";
+ ss << "utils::IteratorType::reduction";
break;
}
});
"lib/Dialect/Test/TestOpsDialect.cpp.inc",
),
(
- ["-gen-enum-decls"],
- "lib/Dialect/Test/TestOpEnums.h.inc",
- ),
- (
- ["-gen-enum-defs"],
- "lib/Dialect/Test/TestOpEnums.cpp.inc",
- ),
- (
["-gen-rewriters"],
"lib/Dialect/Test/TestPatterns.inc",
),
)
gentbl_cc_library(
+ name = "TestEnumDefsIncGen",
+ strip_include_prefix = "lib/Dialect/Test",
+ tbl_outs = [
+ (
+ ["-gen-enum-decls"],
+ "lib/Dialect/Test/TestOpEnums.h.inc",
+ ),
+ (
+ ["-gen-enum-defs"],
+ "lib/Dialect/Test/TestOpEnums.cpp.inc",
+ ),
+ ],
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "lib/Dialect/Test/TestEnumDefs.td",
+ test = True,
+ deps = [
+ ":TestOpTdFiles",
+ ],
+)
+
+gentbl_cc_library(
name = "TestTypeDefsIncGen",
strip_include_prefix = "lib/Dialect/Test",
tbl_outs = [
],
deps = [
":TestAttrDefsIncGen",
+ ":TestEnumDefsIncGen",
":TestInterfacesIncGen",
":TestOpsIncGen",
":TestTypeDefsIncGen",
"//mlir:DerivedAttributeOpInterface",
"//mlir:DestinationStyleOpInterface",
"//mlir:Dialect",
+ "//mlir:DialectUtils",
"//mlir:FuncDialect",
"//mlir:FuncTransforms",
"//mlir:IR",