/// the reduction.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
-/// Attribute name for the AffineArrayAttr which encodes the relationship
-/// between a structured op iterators' and its operands.
-constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
-
-/// Attribute name for the StrArrayAttr which encodes the type of a structured
-/// op's iterators.
-constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
-
-/// Attribute name for the StrArrayAttr which encodes the distribution type for
-/// `linalg.tiled_loop`.
-constexpr StringRef getDistributionTypesAttrName() {
- return "distribution_types";
-}
-
-/// Attribute name for the StringAttr which encodes an optional documentation
-/// string of the structured op.
-constexpr StringRef getDocAttrName() { return "doc"; }
-
-/// Attribute name for the StrArrayAttr which encodes the external library
-/// function that implements the structured op.
-constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
-
-/// Attribute name for the StrArrayAttr which encodes the value of strides.
-constexpr StringRef getStridesAttrName() { return "strides"; }
-
-/// Attribute name for the StrArrayAttr which encodes the value of dilations.
-constexpr StringRef getDilationsAttrName() { return "dilations"; }
-
-/// Attribute name for the StrArrayAttr which encodes the value of paddings.
-constexpr StringRef getPaddingAttrName() { return "padding"; }
-
/// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
m = m.compose(permutationMap);
newIndexingMaps.push_back(m);
}
- genericOp->setAttr(getIndexingMapsAttrName(),
- rewriter.getAffineMapArrayAttr(newIndexingMaps));
+ genericOp.setIndexingMapsAttr(
+ rewriter.getAffineMapArrayAttr(newIndexingMaps));
// 3. Compute the interchanged iterator types.
ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue();
SmallVector<int64_t> permutation(interchangeVector.begin(),
interchangeVector.end());
applyPermutationToVector(itTypesVector, permutation);
- genericOp->setAttr(getIteratorTypesAttrName(),
- ArrayAttr::get(context, itTypesVector));
+ genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(itTypesVector));
// 4. Transform the index operations by applying the permutation map.
if (genericOp.hasIndexSemantics()) {
ArrayRef<IteratorType> iteratorTypes) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
- result.addAttribute(::mlir::getIndexingMapsAttrName(),
+ result.addAttribute(getIndexingMapsAttrName(result.name),
builder.getAffineMapArrayAttr(
AffineMap::inferFromExprList(indexingExprs)));
result.addAttribute(
- ::mlir::getIteratorTypesAttrName(),
+ getIteratorTypesAttrName(result.name),
builder.getArrayAttr(llvm::to_vector(llvm::map_range(
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
return IteratorTypeAttr::get(builder.getContext(), t);
ArrayAttr iteratorTypes, CombiningKind kind) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
- result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
- result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
- result.addAttribute(ContractionOp::getKindAttrStrName(),
+ result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
+ result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
+ result.addAttribute(getKindAttrName(result.name),
CombiningKindAttr::get(builder.getContext(), kind));
}
// represented as an array of strings.
// TODO: Remove this conversion once tests are fixed.
ArrayAttr iteratorTypes =
- result.attributes.get("iterator_types").cast<ArrayAttr>();
+ result.attributes.get(getIteratorTypesAttrName(result.name))
+ .cast<ArrayAttr>();
SmallVector<Attribute> iteratorTypeAttrs;
if (!maybeIteratorType.has_value())
return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
- iteratorTypeAttrs.push_back(IteratorTypeAttr::get(
- parser.getContext(), maybeIteratorType.value()));
+ iteratorTypeAttrs.push_back(
+ IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
}
- result.attributes.set("iterator_types",
+ result.attributes.set(getIteratorTypesAttrName(result.name),
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
- if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
+ if (!result.attributes.get(getKindAttrName(result.name))) {
result.addAttribute(
- ContractionOp::getKindAttrStrName(),
+ getKindAttrName(result.name),
CombiningKindAttr::get(result.getContext(),
ContractionOp::getDefaultKind()));
}
return success();
}
-ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
- static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
- ::mlir::getIteratorTypesAttrName(),
- ContractionOp::getKindAttrStrName()};
- return llvm::makeArrayRef(names);
+SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
+ return SmallVector<StringRef>{getIndexingMapsAttrName(),
+ getIteratorTypesAttrName(), getKindAttrName()};
}
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {