[mlir][ods] Add nested OpTrait
authorJacques Pienaar <jpienaar@google.com>
Tue, 20 Jul 2021 17:44:48 +0000 (10:44 -0700)
committerJacques Pienaar <jpienaar@google.com>
Tue, 20 Jul 2021 17:44:48 +0000 (10:44 -0700)
Allows for grouping OpTraits with list of OpTrait to make it easier to group OpTraits together without needing to use list concats (e.g., enable using `[Traits, ..., UsefulGroupOfTraits, Others, ...]` instead of `[Traits, ...] # UsefulGroupOfTraits # [Others, ...]`). Flatten in construction of Operation. This recurses here as the expectation is that these aren't expected to be deeply nested (most likely only 1 level of nesting).

Differential Revision: https://reviews.llvm.org/D106223

mlir/include/mlir/IR/OpBase.td
mlir/lib/TableGen/Operator.cpp
mlir/test/mlir-tblgen/op-decl-and-defs.td

index 5989019..1168dd1 100644 (file)
@@ -1796,6 +1796,14 @@ class PredTrait<string descr, Pred pred> : Trait {
 // TODO: Remove this class in favor of using Trait.
 class OpTrait;
 
+// Define a OpTrait corresponding to a list of OpTraits, this allows for
+// specifying a list of traits as trait. Avoids needing to do
+// `[Traits, ...] # ListOfTraits # [Others, ...]` while still allowing providing
+// convenient groupings.
+class OpTraitList<list<OpTrait> props> : OpTrait {
+  list<OpTrait> traits = props;
+}
+
 // These classes are used to define operation specific traits.
 class NativeOpTrait<string name> : NativeTrait<name, "Op">, OpTrait;
 class ParamNativeOpTrait<string prop, string params>
index 11f95c0..ea9513d 100644 (file)
@@ -489,11 +489,21 @@ void Operator::populateOpStructure() {
     // This is uniquing based on pointers of the trait.
     SmallPtrSet<const llvm::Init *, 32> traitSet;
     traits.reserve(traitSet.size());
-    for (auto *traitInit : *traitList) {
-      // Keep traits in the same order while skipping over duplicates.
-      if (traitSet.insert(traitInit).second)
-        traits.push_back(Trait::create(traitInit));
-    }
+
+    std::function<void(llvm::ListInit *)> insert;
+    insert = [&](llvm::ListInit *traitList) {
+      for (auto *traitInit : *traitList) {
+        auto *def = cast<DefInit>(traitInit)->getDef();
+        if (def->isSubClassOf("OpTraitList")) {
+          insert(def->getValueAsListInit("traits"));
+          continue;
+        }
+        // Keep traits in the same order while skipping over duplicates.
+        if (traitSet.insert(traitInit).second)
+          traits.push_back(Trait::create(traitInit));
+      }
+    };
+    insert(traitList);
   }
 
   populateTypeInferenceInfo(argumentsAndResultsIndex);
index 4fb9ecb..471bac6 100644 (file)
@@ -261,6 +261,15 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
+// Test usage of OpTraitList getting flattened during emission.
+def NS_KOp : NS_Op<"k_op", [IsolatedFromAbove,
+    OpTraitList<[DeclareOpInterfaceMethods<InferTypeOpInterface>]>]> {
+  let arguments = (ins AnyType:$a, AnyType:$b);
+}
+
+// CHECK: class KOp : public ::mlir::Op<KOp,
+// CHECK-SAME: ::mlir::OpTrait::IsIsolatedFromAbove, ::mlir::InferTypeOpInterface::Trait
+
 // Check native OpTrait usage
 // ---