// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
+
namespace mlir {
class MLIRContext;
class RewritePatternSet;
/// chain.
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
-/// An attribute that specifies the combining function for `vector.contract`,
-/// and `vector.reduction`.
-class CombiningKindAttr
- : public Attribute::AttrBase<CombiningKindAttr, Attribute,
- detail::BitmaskEnumStorage> {
-public:
- using Base::Base;
-
- static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
-
- CombiningKind getKind() const;
-
- void print(AsmPrinter &p) const;
- static Attribute parse(AsmParser &parser, Type type);
-};
-
/// Collects patterns to progressively lower vector.broadcast ops on high-D
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
let genSpecializedAttr = 0;
}
-def Vector_CombiningKindAttr : DialectAttr<
- Vector_Dialect,
- CPred<"$_self.isa<::mlir::vector::CombiningKindAttr>()">,
- "Kind of combining function for contractions and reductions"> {
- let storageType = "::mlir::vector::CombiningKindAttr";
- let returnType = "::mlir::vector::CombiningKind";
- let convertFromStorage = "$_self.getKind()";
- let constBuilderCall =
- "::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
+/// An attribute that specifies the combining function for `vector.contract`,
+/// and `vector.reduction`.
+def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
+ let assemblyFormat = "`<` $value `>`";
}
// TODO: Add an attribute to specify a different algebra with operators other
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
-#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include <numeric>
} // namespace vector
} // namespace mlir
-CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
- MLIRContext *context) {
- return Base::get(context, static_cast<uint64_t>(kind));
-}
-
-CombiningKind CombiningKindAttr::getKind() const {
- return static_cast<CombiningKind>(getImpl()->value);
-}
-
-static constexpr const CombiningKind combiningKindsList[] = {
- // clang-format off
- CombiningKind::ADD,
- CombiningKind::MUL,
- CombiningKind::MINUI,
- CombiningKind::MINSI,
- CombiningKind::MINF,
- CombiningKind::MAXUI,
- CombiningKind::MAXSI,
- CombiningKind::MAXF,
- CombiningKind::AND,
- CombiningKind::OR,
- CombiningKind::XOR,
- // clang-format on
-};
-
-void CombiningKindAttr::print(AsmPrinter &printer) const {
- printer << "<";
- auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
- return bitEnumContains(this->getKind(), kind);
- });
- llvm::interleaveComma(kinds, printer,
- [&](auto kind) { printer << stringifyEnum(kind); });
- printer << ">";
-}
-
-Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
- return {};
-
- StringRef elemName;
- if (failed(parser.parseKeyword(&elemName)))
- return {};
-
- auto kind = symbolizeCombiningKind(elemName);
- if (!kind) {
- parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
- << elemName;
- return {};
- }
-
- if (failed(parser.parseGreater()))
- return {};
-
- return CombiningKindAttr::get(*kind, parser.getContext());
-}
-
-Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
- Type type) const {
- StringRef attrKind;
- if (parser.parseKeyword(&attrKind))
- return {};
-
- if (attrKind == "kind")
- return CombiningKindAttr::parse(parser, {});
-
- parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
- return {};
-}
-
-void VectorDialect::printAttribute(Attribute attr,
- DialectAsmPrinter &os) const {
- if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
- os << "kind";
- ck.print(os);
- return;
- }
- llvm_unreachable("Unknown attribute type");
-}
-
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
void VectorDialect::initialize() {
- addAttributes<CombiningKindAttr>();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
+ >();
addOperations<
#define GET_OP_LIST
result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
result.addAttribute(ContractionOp::getKindAttrStrName(),
- CombiningKindAttr::get(kind, builder.getContext()));
+ CombiningKindAttr::get(builder.getContext(), kind));
}
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
- result.addAttribute(ContractionOp::getKindAttrStrName(),
- CombiningKindAttr::get(ContractionOp::getDefaultKind(),
- result.getContext()));
+ result.addAttribute(
+ ContractionOp::getKindAttrStrName(),
+ CombiningKindAttr::get(result.getContext(),
+ ContractionOp::getDefaultKind()));
}
if (masksInfo.empty())
return success();
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
result.attributes.append(
OuterProductOp::getKindAttrStrName(),
- CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
- result.getContext()));
+ CombiningKindAttr::get(result.getContext(),
+ OuterProductOp::getDefaultKind()));
}
return failure(
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"