[mlir] Change CombiningKind in Vector dialect to EnumAttr.
authorOleg Shyshkov <shyshkov.oleg@gmail.com>
Wed, 7 Sep 2022 11:33:02 +0000 (13:33 +0200)
committerAlexander Belyaev <pifon@google.com>
Wed, 7 Sep 2022 11:40:45 +0000 (13:40 +0200)
CombiningKind was implemented before EnumAttr, so it reimplements the same behaviour with the custom code. Except for a few places, EnumAttr is a drop-in replacement.

Reviewed By: nicolasvasilache, pifon2a

Differential Revision: https://reviews.llvm.org/D133343

mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index e881e03..2e56afe 100644 (file)
@@ -4,5 +4,7 @@ add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
 set(LLVM_TARGET_DEFINITIONS VectorOps.td)
 mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
 add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)
index d51c559..cd92dfd 100644 (file)
@@ -29,6 +29,9 @@
 // Pull in all enum type definitions and utility function declarations.
 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
+
 namespace mlir {
 class MLIRContext;
 class RewritePatternSet;
@@ -113,22 +116,6 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
 /// chain.
 void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
 
-/// An attribute that specifies the combining function for `vector.contract`,
-/// and `vector.reduction`.
-class CombiningKindAttr
-    : public Attribute::AttrBase<CombiningKindAttr, Attribute,
-                                 detail::BitmaskEnumStorage> {
-public:
-  using Base::Base;
-
-  static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
-
-  CombiningKind getKind() const;
-
-  void print(AsmPrinter &p) const;
-  static Attribute parse(AsmParser &parser, Type type);
-};
-
 /// Collects patterns to progressively lower vector.broadcast ops on high-D
 /// vectors to low-D vector ops.
 void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
index 9a41bda..ec29d53 100644 (file)
@@ -57,15 +57,10 @@ def CombiningKind : I32BitEnumAttr<
   let genSpecializedAttr = 0;
 }
 
-def Vector_CombiningKindAttr : DialectAttr<
-    Vector_Dialect,
-    CPred<"$_self.isa<::mlir::vector::CombiningKindAttr>()">,
-    "Kind of combining function for contractions and reductions"> {
-  let storageType = "::mlir::vector::CombiningKindAttr";
-  let returnType = "::mlir::vector::CombiningKind";
-  let convertFromStorage = "$_self.getKind()";
-  let constBuilderCall =
-          "::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
+/// An attribute that specifies the combining function for `vector.contract`,
+/// and `vector.reduction`.
+def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
+  let assemblyFormat = "`<` $value `>`";
 }
 
 // TODO: Add an attribute to specify a different algebra with operators other
index 3d3b872..79c3719 100644 (file)
@@ -30,8 +30,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
 #include <numeric>
 
@@ -227,91 +227,15 @@ struct BitmaskEnumStorage : public AttributeStorage {
 } // namespace vector
 } // namespace mlir
 
-CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
-                                         MLIRContext *context) {
-  return Base::get(context, static_cast<uint64_t>(kind));
-}
-
-CombiningKind CombiningKindAttr::getKind() const {
-  return static_cast<CombiningKind>(getImpl()->value);
-}
-
-static constexpr const CombiningKind combiningKindsList[] = {
-    // clang-format off
-    CombiningKind::ADD,
-    CombiningKind::MUL,
-    CombiningKind::MINUI,
-    CombiningKind::MINSI,
-    CombiningKind::MINF,
-    CombiningKind::MAXUI,
-    CombiningKind::MAXSI,
-    CombiningKind::MAXF,
-    CombiningKind::AND,
-    CombiningKind::OR,
-    CombiningKind::XOR,
-    // clang-format on
-};
-
-void CombiningKindAttr::print(AsmPrinter &printer) const {
-  printer << "<";
-  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
-    return bitEnumContains(this->getKind(), kind);
-  });
-  llvm::interleaveComma(kinds, printer,
-                        [&](auto kind) { printer << stringifyEnum(kind); });
-  printer << ">";
-}
-
-Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-
-  StringRef elemName;
-  if (failed(parser.parseKeyword(&elemName)))
-    return {};
-
-  auto kind = symbolizeCombiningKind(elemName);
-  if (!kind) {
-    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
-        << elemName;
-    return {};
-  }
-
-  if (failed(parser.parseGreater()))
-    return {};
-
-  return CombiningKindAttr::get(*kind, parser.getContext());
-}
-
-Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
-                                        Type type) const {
-  StringRef attrKind;
-  if (parser.parseKeyword(&attrKind))
-    return {};
-
-  if (attrKind == "kind")
-    return CombiningKindAttr::parse(parser, {});
-
-  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
-  return {};
-}
-
-void VectorDialect::printAttribute(Attribute attr,
-                                   DialectAsmPrinter &os) const {
-  if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
-    os << "kind";
-    ck.print(os);
-    return;
-  }
-  llvm_unreachable("Unknown attribute type");
-}
-
 //===----------------------------------------------------------------------===//
 // VectorDialect
 //===----------------------------------------------------------------------===//
 
 void VectorDialect::initialize() {
-  addAttributes<CombiningKindAttr>();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
+      >();
 
   addOperations<
 #define GET_OP_LIST
@@ -558,7 +482,7 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
   result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
   result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
   result.addAttribute(ContractionOp::getKindAttrStrName(),
-                      CombiningKindAttr::get(kind, builder.getContext()));
+                      CombiningKindAttr::get(builder.getContext(), kind));
 }
 
 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -587,9 +511,10 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
   if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
-    result.addAttribute(ContractionOp::getKindAttrStrName(),
-                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
-                                               result.getContext()));
+    result.addAttribute(
+        ContractionOp::getKindAttrStrName(),
+        CombiningKindAttr::get(result.getContext(),
+                               ContractionOp::getDefaultKind()));
   }
   if (masksInfo.empty())
     return success();
@@ -2385,8 +2310,8 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
   if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
     result.attributes.append(
         OuterProductOp::getKindAttrStrName(),
-        CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
-                               result.getContext()));
+        CombiningKindAttr::get(result.getContext(),
+                               OuterProductOp::getDefaultKind()));
   }
 
   return failure(
@@ -5179,5 +5104,8 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
index 3ebf6df..59474bc 100644 (file)
@@ -1111,7 +1111,8 @@ func.func @bitcast_sizemismatch(%arg0 : vector<5x1x3x2xf32>) {
 // -----
 
 func.func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
-  // expected-error@+1 {{custom op 'vector.reduction' Unknown combining kind: joho}}
+  // expected-error@+2 {{custom op 'vector.reduction' failed to parse Vector_CombiningKindAttr parameter 'value' which is to be a `::mlir::vector::CombiningKind`}}
+  // expected-error@+1 {{custom op 'vector.reduction' expected ::mlir::vector::CombiningKind to be one of: }}
   %0 = vector.reduction <joho>, %arg0 : vector<16xf32> into f32
 }
 
index bb2465e..68d8250 100644 (file)
@@ -7759,6 +7759,14 @@ gentbl_cc_library(
             "include/mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc",
         ),
         (
+            ["-gen-attrdef-decls"],
+            "include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc",
+        ),
+        (
+            ["-gen-attrdef-defs"],
+            "include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc",
+        ),
+        (
             ["-gen-op-doc"],
             "g3doc/Dialects/Vector/VectorOps.md",
         ),