This matches the current support provided to operations, and allows attaching traits, interfaces, and using the DeclareInterfaceMethods utility. This was missed when attribute/type generation was first added.
Differential Revision: https://reviews.llvm.org/D100233
//===----------------------------------------------------------------------===//
class PDL_Type<string name, string typeMnemonic>
- : TypeDef<PDL_Dialect, name, "::mlir::pdl::PDLType"> {
+ : TypeDef<PDL_Dialect, name, [], "::mlir::pdl::PDLType"> {
let mnemonic = typeMnemonic;
}
// Base class for Builtin dialect attributes.
class Builtin_Attr<string name, string baseCppClass = "::mlir::Attribute">
- : AttrDef<Builtin_Dialect, name, baseCppClass> {
+ : AttrDef<Builtin_Dialect, name, [], baseCppClass> {
let mnemonic = ?;
}
// Base class for Builtin dialect location attributes.
class Builtin_LocationAttr<string name>
- : AttrDef<Builtin_Dialect, name, "::mlir::LocationAttr"> {
+ : AttrDef<Builtin_Dialect, name, [], "::mlir::LocationAttr"> {
let cppClassName = name;
let mnemonic = ?;
}
// Base class for Builtin dialect types.
class Builtin_Type<string name, string baseCppClass = "::mlir::Type">
- : TypeDef<Builtin_Dialect, name, baseCppClass> {
+ : TypeDef<Builtin_Dialect, name, [], baseCppClass> {
let mnemonic = ?;
}
//===----------------------------------------------------------------------===//
// Base class for Builtin dialect float types.
-class Builtin_FloatType<string name> : TypeDef<Builtin_Dialect, name,
- "::mlir::FloatType"> {
+class Builtin_FloatType<string name> : Builtin_Type<name, "::mlir::FloatType"> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
class VariadicSuccessor<Successor successor>
: Successor<successor.predicate, successor.summary>;
+
//===----------------------------------------------------------------------===//
-// OpTrait definitions
+// Trait definitions
//===----------------------------------------------------------------------===//
-// OpTrait represents a trait regarding an op.
-class OpTrait;
+// Trait represents a trait regarding an attribute, operation, or type.
+class Trait;
-// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The
-// purpose to wrap around C++ symbol string with this class is to make
-// traits specified for ops in TableGen less alien and more integrated.
-class NativeOpTrait<string name> : OpTrait {
+// NativeTrait corresponds to the MLIR C++ trait mechanism. The purpose to wrap
+// around C++ symbol string with this class is to make traits specified for
+// entities in TableGen less alien and more integrated.
+class NativeTrait<string name, string entityType> : Trait {
string trait = name;
- string cppNamespace = "::mlir::OpTrait";
+ string cppNamespace = "::mlir::" # entityType # "Trait";
}
-// ParamNativeOpTrait corresponds to the template-parameterized traits in the
-// C++ implementation. MLIR uses nested class templates to implement such
-// traits leading to constructs of the form "TraitName<Parameters>::Impl". Use
-// the value in `prop` as the trait name and the value in `params` as
-// parameters to construct the native trait class name.
-class ParamNativeOpTrait<string prop, string params>
- : NativeOpTrait<prop # "<" # params # ">::Impl">;
+// ParamNativeTrait corresponds to the template-parameterized traits in the C++
+// implementation. MLIR uses nested class templates to implement such traits
+// leading to constructs of the form "TraitName<Parameters>::Impl". Use the
+// value in `prop` as the trait name and the value in `params` as parameters to
+// construct the native trait class name.
+class ParamNativeTrait<string prop, string params, string entityType>
+ : NativeTrait<prop # "<" # params # ">::Impl", entityType>;
-// GenInternalOpTrait is an op trait that does not have direct C++ mapping but
-// affects op definition generator internals, like how op builders and
+// GenInternalTrait is a trait that does not have direct C++ mapping but affects
+// an entities definition generator internals, like how operation builders and
// operand/attribute/result getters are generated.
-class GenInternalOpTrait<string prop> : OpTrait {
- string trait = "::mlir::OpTrait::" # prop;
+class GenInternalTrait<string prop, string entityType> : Trait {
+ string trait = "::mlir::" # entityType # "Trait::" # prop;
}
-// PredOpTrait is an op trait implemented by way of a predicate on the op.
-class PredOpTrait<string descr, Pred pred> : OpTrait {
+// PredTrait is a trait implemented by way of a predicate on an entity.
+class PredTrait<string descr, Pred pred> : Trait {
string summary = descr;
Pred predicate = pred;
}
+//===----------------------------------------------------------------------===//
+// OpTrait definitions
+//===----------------------------------------------------------------------===//
+
+// OpTrait represents a trait regarding an operation.
+// TODO: Remove this class in favor of using Trait.
+class OpTrait;
+
+// These classes are used to define operation specific traits.
+class NativeOpTrait<string name> : NativeTrait<name, "Op">, OpTrait;
+class ParamNativeOpTrait<string prop, string params>
+ : ParamNativeTrait<prop, params, "Op">, OpTrait;
+class GenInternalOpTrait<string prop> : GenInternalTrait<prop, "Op">, OpTrait;
+class PredOpTrait<string descr, Pred pred> : PredTrait<descr, pred>, OpTrait;
+
// Op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// Op defines an automatic allocation scope.
string defaultValue = value;
}
-// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
-// C++. The purpose to wrap around C++ symbol string with this class is to make
+// InterfaceTrait corresponds to a specific 'Interface' class defined in C++.
+// The purpose to wrap around C++ symbol string with this class is to make
// interfaces specified for ops in TableGen less alien and more integrated.
-class OpInterfaceTrait<string name, code verifyBody = [{}]>
- : NativeOpTrait<""> {
+class InterfaceTrait<string name> : NativeTrait<"", ""> {
let trait = name # "::Trait";
let cppNamespace = "";
- // Specify the body of the verification function. `$_op` will be replaced with
- // the operation being verified.
- code verify = verifyBody;
-
// An optional code block containing extra declarations to place in the
// interface trait declaration.
code extraTraitClassDeclaration = "";
}
+// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
+// C++. The purpose to wrap around C++ symbol string with this class is to make
+// interfaces specified for ops in TableGen less alien and more integrated.
+class OpInterfaceTrait<string name, code verifyBody = [{}]>
+ : InterfaceTrait<name>, OpTrait {
+ // Specify the body of the verification function. `$_op` will be replaced with
+ // the operation being verified.
+ code verify = verifyBody;
+}
+
// This class represents a single, optionally static, interface method.
// Note: non-static interface methods have an implicit parameter, either
// $_op/$_attr/$_type corresponding to an instance of the derived value.
}
// AttrInterface represents an interface registered to an attribute.
-class AttrInterface<string name> : Interface<name> {
- // An optional code block containing extra declarations to place in the
- // interface trait declaration.
- code extraTraitClassDeclaration = "";
-}
+class AttrInterface<string name> : Interface<name>, InterfaceTrait<name>;
// OpInterface represents an interface registered to an operation.
class OpInterface<string name> : Interface<name>, OpInterfaceTrait<name>;
// TypeInterface represents an interface registered to a type.
-class TypeInterface<string name> : Interface<name> {
- // An optional code block containing extra declarations to place in the
- // interface trait declaration.
- code extraTraitClassDeclaration = "";
-}
+class TypeInterface<string name> : Interface<name>, InterfaceTrait<name>;
-// Whether to declare the op interface methods in the op's header. This class
-// simply wraps an OpInterface but is used to indicate that the method
+// Whether to declare the interface methods in the user entity's header. This
+// class simply wraps an Interface but is used to indicate that the method
// declarations should be generated. This class takes an optional set of methods
// that should have declarations generated even if the method has a default
// implementation.
+class DeclareInterfaceMethods<Interface interface,
+ list<string> overridenMethods = []> {
+ // This field contains a set of method names that should always have their
+ // declarations generated. This allows for generating declarations for
+ // methods with default implementations that need to be overridden.
+ list<string> alwaysOverriddenMethods = overridenMethods;
+}
+class DeclareAttrInterfaceMethods<AttrInterface interface,
+ list<string> overridenMethods = []>
+ : DeclareInterfaceMethods<interface, overridenMethods>,
+ AttrInterface<interface.cppClassName> {
+ let description = interface.description;
+ let cppClassName = interface.cppClassName;
+ let cppNamespace = interface.cppNamespace;
+ let methods = interface.methods;
+}
class DeclareOpInterfaceMethods<OpInterface interface,
list<string> overridenMethods = []>
- : OpInterface<interface.cppClassName> {
+ : DeclareInterfaceMethods<interface, overridenMethods>,
+ OpInterface<interface.cppClassName> {
+ let description = interface.description;
+ let cppClassName = interface.cppClassName;
+ let cppNamespace = interface.cppNamespace;
+ let methods = interface.methods;
+}
+class DeclareTypeInterfaceMethods<TypeInterface interface,
+ list<string> overridenMethods = []>
+ : DeclareInterfaceMethods<interface, overridenMethods>,
+ TypeInterface<interface.cppClassName> {
let description = interface.description;
let cppClassName = interface.cppClassName;
let cppNamespace = interface.cppNamespace;
let methods = interface.methods;
-
- // This field contains a set of method names that should always have their
- // declarations generated. This allows for generating declarations for
- // methods with default implementations that need to be overridden.
- list<string> alwaysOverriddenMethods = overridenMethods;
}
//===----------------------------------------------------------------------===//
// Define a new attribute or type, named `name`, that inherits from the given
// C++ base class.
-class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
+class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
+ string baseCppClass> {
// The name of the C++ base class to use for this def.
string cppBaseClassName = baseCppClass;
// Note that builders should only be provided when a def has parameters.
list<AttrOrTypeBuilder> builders = ?;
+ // The list of traits attached to this def.
+ list<Trait> traits = defTraits;
+
// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
// the printer/parser.
// Define a new attribute, named `name`, belonging to `dialect` that inherits
// from the given C++ base class.
-class AttrDef<Dialect dialect, string name,
+class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: DialectAttr<dialect, CPred<"">, /*descr*/"">,
- AttrOrTypeDef<"Attr", name, baseCppClass> {
+ AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
// The name of the C++ Attribute class.
string cppClassName = name # "Attr";
// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
-class TypeDef<Dialect dialect, string name,
+class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
- AttrOrTypeDef<"Type", name, baseCppClass> {
+ AttrOrTypeDef<"Type", name, traits, baseCppClass> {
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Trait.h"
namespace llvm {
class DagInit;
// Returns the builders of this def.
ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
+ // Returns the traits of this def.
+ ArrayRef<Trait> getTraits() const { return traits; }
+
// Returns whether two AttrOrTypeDefs are equal by checking the equality of
// the underlying record.
bool operator==(const AttrOrTypeDef &other) const;
protected:
const llvm::Record *def;
- // The builders of this type definition.
+ // The builders of this definition.
SmallVector<AttrOrTypeBuilder> builders;
+
+ // The traits of this definition.
+ SmallVector<Trait> traits;
};
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Dialect.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Region.h"
#include "mlir/TableGen/Successor.h"
+#include "mlir/TableGen/Trait.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
var_decorator_range getArgDecorators(int index) const;
// Returns the trait wrapper for the given MLIR C++ `trait`.
- // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
- // requiring the raw MLIR trait here.
- const OpTrait *getTrait(llvm::StringRef trait) const;
+ const Trait *getTrait(llvm::StringRef trait) const;
// Regions.
using const_region_iterator = const NamedRegion *;
unsigned getNumVariadicSuccessors() const;
// Trait.
- using const_trait_iterator = const OpTrait *;
+ using const_trait_iterator = const Trait *;
const_trait_iterator trait_begin() const;
const_trait_iterator trait_end() const;
llvm::iterator_range<const_trait_iterator> getTraits() const;
SmallVector<NamedSuccessor, 0> successors;
// The traits of the op.
- SmallVector<OpTrait, 4> traits;
+ SmallVector<Trait, 4> traits;
// The regions of this op.
SmallVector<NamedRegion, 1> regions;
// This class represents an instance of a side effect interface applied to an
// operation. This is a wrapper around an OpInterfaceTrait that also includes
// the effects that are applied.
-class SideEffectTrait : public InterfaceOpTrait {
+class SideEffectTrait : public InterfaceTrait {
public:
// Return the effects that are attached to the side effect interface.
Operator::var_decorator_range getEffects() const;
// Return the name of the base C++ effect.
StringRef getBaseEffectName() const;
- static bool classof(const OpTrait *t);
+ static bool classof(const Trait *t);
};
} // end namespace tblgen
-//===- OpTrait.h - OpTrait wrapper class ------------------------*- C++ -*-===//
+//===- Trait.h - Trait wrapper class ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//
//
-// OpTrait wrapper to simplify using TableGen Record defining an MLIR OpTrait.
+// Trait wrapper to simplify using TableGen Record defining an MLIR Trait.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TABLEGEN_OPTRAIT_H_
-#define MLIR_TABLEGEN_OPTRAIT_H_
+#ifndef MLIR_TABLEGEN_TRAIT_H_
+#define MLIR_TABLEGEN_TRAIT_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
namespace tblgen {
-struct OpInterface;
+class Interface;
-// Wrapper class with helper methods for accessing OpTrait constraints defined
-// in TableGen.
-class OpTrait {
+// Wrapper class with helper methods for accessing Trait constraints defined in
+// TableGen.
+class Trait {
public:
- // Discriminator for kinds of op traits.
+ // Discriminator for kinds of traits.
enum class Kind {
- // OpTrait corresponding to C++ class.
+ // Trait corresponding to C++ class.
Native,
- // OpTrait corresponding to predicate on operation.
+ // Trait corresponding to a predicate.
Pred,
- // OpTrait controlling op definition generator internals.
+ // Trait controlling definition generator internals.
Internal,
- // OpTrait corresponding to OpInterface.
+ // Trait corresponding to an Interface.
Interface
};
- explicit OpTrait(Kind kind, const llvm::Record *def);
+ explicit Trait(Kind kind, const llvm::Record *def);
- // Returns an OpTrait corresponding to the init provided.
- static OpTrait create(const llvm::Init *init);
+ // Returns an Trait corresponding to the init provided.
+ static Trait create(const llvm::Init *init);
Kind getKind() const { return kind; }
Kind kind;
};
-// OpTrait corresponding to a native C++ OpTrait.
-class NativeOpTrait : public OpTrait {
+// Trait corresponding to a native C++ Trait.
+class NativeTrait : public Trait {
public:
// Returns the trait corresponding to a C++ trait class.
- std::string getTrait() const;
+ std::string getFullyQualifiedTraitName() const;
- static bool classof(const OpTrait *t) { return t->getKind() == Kind::Native; }
+ static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
};
-// OpTrait corresponding to a predicate on the operation.
-class PredOpTrait : public OpTrait {
+// Trait corresponding to a predicate on the operation.
+class PredTrait : public Trait {
public:
// Returns the template for constructing the predicate.
std::string getPredTemplate() const;
// Returns the description of what the predicate is verifying.
StringRef getSummary() const;
- static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; }
+ static bool classof(const Trait *t) { return t->getKind() == Kind::Pred; }
};
-// OpTrait controlling op definition generator internals.
-class InternalOpTrait : public OpTrait {
+// Trait controlling op definition generator internals.
+class InternalTrait : public Trait {
public:
// Returns the trait controlling op definition generator internals.
- StringRef getTrait() const;
+ StringRef getFullyQualifiedTraitName() const;
- static bool classof(const OpTrait *t) {
- return t->getKind() == Kind::Internal;
- }
+ static bool classof(const Trait *t) { return t->getKind() == Kind::Internal; }
};
-// OpTrait corresponding to an OpInterface on the operation.
-class InterfaceOpTrait : public OpTrait {
+// Trait corresponding to an OpInterface on the operation.
+class InterfaceTrait : public Trait {
public:
- // Returns member function definitions corresponding to the trait,
- OpInterface getOpInterface() const;
+ // Returns interface corresponding to the trait.
+ Interface getInterface() const;
// Returns the trait corresponding to a C++ trait class.
- std::string getTrait() const;
+ std::string getFullyQualifiedTraitName() const;
- static bool classof(const OpTrait *t) {
+ static bool classof(const Trait *t) {
return t->getKind() == Kind::Interface;
}
} // end namespace tblgen
} // end namespace mlir
-#endif // MLIR_TABLEGEN_OPTRAIT_H_
+#endif // MLIR_TABLEGEN_TRAIT_H_
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
builders.emplace_back(builder);
}
}
+
+ // Populate the traits.
+ if (auto *traitList = def->getValueAsListInit("traits")) {
+ SmallPtrSet<const llvm::Init *, 32> traitSet;
+ traits.reserve(traitSet.size());
+ for (auto *traitInit : *traitList)
+ if (traitSet.insert(traitInit).second)
+ traits.push_back(Trait::create(traitInit));
+ }
}
Dialect AttrOrTypeDef::getDialect() const {
Interfaces.cpp
Operator.cpp
OpClass.cpp
- OpTrait.cpp
Pass.cpp
Pattern.cpp
Predicate.cpp
Region.cpp
SideEffects.cpp
Successor.cpp
+ Trait.cpp
Type.cpp
DISABLE_LLVM_LINK_LLVM_DYLIB
+++ /dev/null
-//===- OpTrait.cpp - OpTrait class ----------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// OpTrait wrapper to simplify using TableGen Record defining a MLIR OpTrait.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/TableGen/OpTrait.h"
-#include "mlir/TableGen/Interfaces.h"
-#include "mlir/TableGen/Predicate.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/TableGen/Error.h"
-#include "llvm/TableGen/Record.h"
-
-using namespace mlir;
-using namespace mlir::tblgen;
-
-OpTrait OpTrait::create(const llvm::Init *init) {
- auto def = cast<llvm::DefInit>(init)->getDef();
- if (def->isSubClassOf("PredOpTrait"))
- return OpTrait(Kind::Pred, def);
- if (def->isSubClassOf("GenInternalOpTrait"))
- return OpTrait(Kind::Internal, def);
- if (def->isSubClassOf("OpInterfaceTrait"))
- return OpTrait(Kind::Interface, def);
- assert(def->isSubClassOf("NativeOpTrait"));
- return OpTrait(Kind::Native, def);
-}
-
-OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
-
-std::string NativeOpTrait::getTrait() const {
- llvm::StringRef trait = def->getValueAsString("trait");
- llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
- return cppNamespace.empty() ? trait.str()
- : (cppNamespace + "::" + trait).str();
-}
-
-llvm::StringRef InternalOpTrait::getTrait() const {
- return def->getValueAsString("trait");
-}
-
-std::string PredOpTrait::getPredTemplate() const {
- auto pred = Pred(def->getValueInit("predicate"));
- return pred.getCondition();
-}
-
-llvm::StringRef PredOpTrait::getSummary() const {
- return def->getValueAsString("summary");
-}
-
-OpInterface InterfaceOpTrait::getOpInterface() const {
- return OpInterface(def);
-}
-
-std::string InterfaceOpTrait::getTrait() const {
- llvm::StringRef trait = def->getValueAsString("trait");
- llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
- return cppNamespace.empty() ? trait.str()
- : (cppNamespace + "::" + trait).str();
-}
-
-bool InterfaceOpTrait::shouldDeclareMethods() const {
- return def->isSubClassOf("DeclareOpInterfaceMethods");
-}
-
-std::vector<StringRef> InterfaceOpTrait::getAlwaysDeclaredMethods() const {
- return def->getValueAsListOfStrings("alwaysOverriddenMethods");
-}
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Predicate.h"
+#include "mlir/TableGen/Trait.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
return *arg->getValueAsListInit("decorators");
}
-const OpTrait *Operator::getTrait(StringRef trait) const {
+const Trait *Operator::getTrait(StringRef trait) const {
for (const auto &t : traits) {
- if (const auto *opTrait = dyn_cast<NativeOpTrait>(&t)) {
- if (opTrait->getTrait() == trait)
- return opTrait;
- } else if (const auto *opTrait = dyn_cast<InternalOpTrait>(&t)) {
- if (opTrait->getTrait() == trait)
- return opTrait;
- } else if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&t)) {
- if (opTrait->getTrait() == trait)
- return opTrait;
+ if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
+ if (traitDef->getFullyQualifiedTraitName() == trait)
+ return traitDef;
+ } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
+ if (traitDef->getFullyQualifiedTraitName() == trait)
+ return traitDef;
+ } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
+ if (traitDef->getFullyQualifiedTraitName() == trait)
+ return traitDef;
}
}
return nullptr;
return found;
};
- for (const OpTrait &trait : traits) {
+ for (const Trait &trait : traits) {
const llvm::Record &def = trait.getDef();
// If the infer type op interface was manually added, then treat it as
// intention that the op needs special handling.
if (def.isSubClassOf(
llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
return;
- if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&trait))
- if (&opTrait->getDef() == inferTrait)
+ if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
+ if (&traitDef->getDef() == inferTrait)
return;
if (!def.isSubClassOf("AllTypesMatch"))
// If the types could be computed, then add type inference trait.
if (allResultsHaveKnownTypes)
- traits.push_back(OpTrait::create(inferTrait->getDefInit()));
+ traits.push_back(Trait::create(inferTrait->getDefInit()));
}
void Operator::populateOpStructure() {
for (auto *traitInit : *traitList) {
// Keep traits in the same order while skipping over duplicates.
if (traitSet.insert(traitInit).second)
- traits.push_back(OpTrait::create(traitInit));
+ traits.push_back(Trait::create(traitInit));
}
}
return def->getValueAsString("baseEffectName");
}
-bool SideEffectTrait::classof(const OpTrait *t) {
+bool SideEffectTrait::classof(const Trait *t) {
return t->getDef().isSubClassOf("SideEffectsTraitBase");
}
--- /dev/null
+//===- Trait.cpp ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Trait wrapper to simplify using TableGen Record defining a MLIR Trait.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Trait.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// Trait
+//===----------------------------------------------------------------------===//
+
+Trait Trait::create(const llvm::Init *init) {
+ auto def = cast<llvm::DefInit>(init)->getDef();
+ if (def->isSubClassOf("PredTrait"))
+ return Trait(Kind::Pred, def);
+ if (def->isSubClassOf("GenInternalTrait"))
+ return Trait(Kind::Internal, def);
+ if (def->isSubClassOf("InterfaceTrait"))
+ return Trait(Kind::Interface, def);
+ assert(def->isSubClassOf("NativeTrait"));
+ return Trait(Kind::Native, def);
+}
+
+Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
+
+//===----------------------------------------------------------------------===//
+// NativeTrait
+//===----------------------------------------------------------------------===//
+
+std::string NativeTrait::getFullyQualifiedTraitName() const {
+ llvm::StringRef trait = def->getValueAsString("trait");
+ llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
+ return cppNamespace.empty() ? trait.str()
+ : (cppNamespace + "::" + trait).str();
+}
+
+//===----------------------------------------------------------------------===//
+// InternalTrait
+//===----------------------------------------------------------------------===//
+
+llvm::StringRef InternalTrait::getFullyQualifiedTraitName() const {
+ return def->getValueAsString("trait");
+}
+
+//===----------------------------------------------------------------------===//
+// PredTrait
+//===----------------------------------------------------------------------===//
+
+std::string PredTrait::getPredTemplate() const {
+ auto pred = Pred(def->getValueInit("predicate"));
+ return pred.getCondition();
+}
+
+llvm::StringRef PredTrait::getSummary() const {
+ return def->getValueAsString("summary");
+}
+
+//===----------------------------------------------------------------------===//
+// InterfaceTrait
+//===----------------------------------------------------------------------===//
+
+Interface InterfaceTrait::getInterface() const { return Interface(def); }
+
+std::string InterfaceTrait::getFullyQualifiedTraitName() const {
+ llvm::StringRef trait = def->getValueAsString("trait");
+ llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
+ return cppNamespace.empty() ? trait.str()
+ : (cppNamespace + "::" + trait).str();
+}
+
+bool InterfaceTrait::shouldDeclareMethods() const {
+ return def->isSubClassOf("DeclareInterfaceMethods");
+}
+
+std::vector<StringRef> InterfaceTrait::getAlwaysDeclaredMethods() const {
+ return def->getValueAsListOfStrings("alwaysOverriddenMethods");
+}
// A type interface used to test the ODS generation of type interfaces.
def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
+ let cppNamespace = "::mlir::test";
let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeA", (ins "Location":$loc), [{
// To get the test dialect def.
include "TestOps.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
// All of the types will extend this class.
-class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
+class Test_Type<string name, list<Trait> traits = []>
+ : TypeDef<Test_Dialect, name, traits>;
def SimpleTypeA : Test_Type<"SimpleA"> {
let mnemonic = "smpla";
let mnemonic = "struct";
}
+def TestType : Test_Type<"Test", [
+ DeclareTypeInterfaceMethods<TestTypeInterface>
+]> {
+ let mnemonic = "test_type";
+}
+
+def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
+ DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["areCompatible"]>
+]> {
+ let mnemonic = "test_type_with_layout";
+ let parameters = (ins "unsigned":$key);
+ let extraClassDeclaration = [{
+ LogicalResult verifyEntries(DataLayoutEntryListRef params,
+ Location loc) const;
+
+ private:
+ unsigned extractKind(DataLayoutEntryListRef params,
+ StringRef expectedKind) const;
+
+ public:
+ }];
+}
+
#endif // TEST_TYPEDEFS
}
}
+// The functions don't need to be in the header file, but need to be in the mlir
+// namespace. Declare them here, then define them immediately below. Separating
+// the declaration and definition adheres to the LLVM coding standards.
+namespace mlir {
+namespace test {
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool operator==(const FieldInfo &a, const FieldInfo &b);
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
+} // namespace test
+} // namespace mlir
+
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) {
+ return a.name == b.name && a.type == b.type;
+}
+
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
+ return llvm::hash_combine(fi.name, fi.type);
+}
+
+//===----------------------------------------------------------------------===//
+// CompoundAType
+//===----------------------------------------------------------------------===//
+
Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
int widthOfSomething;
Type oneType;
printer << "]>";
}
-// The functions don't need to be in the header file, but need to be in the mlir
-// namespace. Declare them here, then define them immediately below. Separating
-// the declaration and definition adheres to the LLVM coding standards.
-namespace mlir {
-namespace test {
-// FieldInfo is used as part of a parameter, so equality comparison is
-// compulsory.
-static bool operator==(const FieldInfo &a, const FieldInfo &b);
-// FieldInfo is used as part of a parameter, so a hash will be computed.
-static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
-} // namespace test
-} // namespace mlir
-
-// FieldInfo is used as part of a parameter, so equality comparison is
-// compulsory.
-static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) {
- return a.name == b.name && a.type == b.type;
-}
-
-// FieldInfo is used as part of a parameter, so a hash will be computed.
-static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
- return llvm::hash_combine(fi.name, fi.type);
-}
+//===----------------------------------------------------------------------===//
+// TestIntegerType
+//===----------------------------------------------------------------------===//
// Example type validity checker.
LogicalResult
}
//===----------------------------------------------------------------------===//
-// Tablegen Generated Definitions
+// TestType
//===----------------------------------------------------------------------===//
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.cpp.inc"
+void TestType::printTypeC(Location loc) const {
+ emitRemark(loc) << *this << " - TestC";
+}
+
+//===----------------------------------------------------------------------===//
+// TestTypeWithLayout
+//===----------------------------------------------------------------------===//
+
+Type TestTypeWithLayoutType::parse(MLIRContext *ctx, DialectAsmParser &parser) {
+ unsigned val;
+ if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater())
+ return Type();
+ return TestTypeWithLayoutType::get(ctx, val);
+}
+
+void TestTypeWithLayoutType::print(DialectAsmPrinter &printer) const {
+ printer << "test_type_with_layout<" << getKey() << ">";
+}
+
+unsigned
+TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
+ DataLayoutEntryListRef params) const {
+ return extractKind(params, "size");
+}
-LogicalResult TestTypeWithLayout::verifyEntries(DataLayoutEntryListRef params,
- Location loc) const {
+unsigned
+TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
+ DataLayoutEntryListRef params) const {
+ return extractKind(params, "alignment");
+}
+
+unsigned TestTypeWithLayoutType::getPreferredAlignment(
+ const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
+ return extractKind(params, "preferred");
+}
+
+bool TestTypeWithLayoutType::areCompatible(
+ DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
+ unsigned old = extractKind(oldLayout, "alignment");
+ return old == 1 || extractKind(newLayout, "alignment") <= old;
+}
+
+LogicalResult
+TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
+ Location loc) const {
for (DataLayoutEntryInterface entry : params) {
// This is for testing purposes only, so assert well-formedness.
assert(entry.isTypeEntry() && "unexpected identifier entry");
- assert(entry.getKey().get<Type>().isa<TestTypeWithLayout>() &&
+ assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
"wrong type passed in");
auto array = entry.getValue().dyn_cast<ArrayAttr>();
assert(array && array.getValue().size() == 2 &&
return success();
}
-unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
- StringRef expectedKind) const {
+unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
+ StringRef expectedKind) const {
for (DataLayoutEntryInterface entry : params) {
ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue();
StringRef kind = pair.front().cast<StringAttr>().getValue();
}
//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
+
+//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
void TestDialect::registerTypes() {
- addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
+ addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
if (parseResult.hasValue())
return genType;
}
- if (typeTag == "test_type")
- return TestType::get(parser.getBuilder().getContext());
-
- if (typeTag == "test_type_with_layout") {
- unsigned val;
- if (parser.parseLess() || parser.parseInteger(val) ||
- parser.parseGreater()) {
- return Type();
- }
- return TestTypeWithLayout::get(parser.getBuilder().getContext(), val);
- }
if (typeTag != "test_rec") {
parser.emitError(parser.getNameLoc()) << "unknown type!";
llvm::SetVector<Type> &stack) {
if (succeeded(generatedTypePrinter(type, printer)))
return;
- if (type.isa<TestType>()) {
- printer << "test_type";
- return;
- }
-
- if (auto t = type.dyn_cast<TestTypeWithLayout>()) {
- printer << "test_type_with_layout<" << t.getKey() << ">";
- return;
- }
auto rec = type.cast<TestRecursiveType>();
printer << "test_rec<" << rec.getName();
} // namespace test
} // namespace mlir
+#include "TestTypeInterfaces.h.inc"
+
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"
namespace mlir {
namespace test {
-#include "TestTypeInterfaces.h.inc"
-
-/// This class is a simple test type that uses a generated interface.
-struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
- TestTypeInterface::Trait> {
- using Base::Base;
-
- /// Provide a definition for the necessary interface methods.
- void printTypeC(Location loc) const {
- emitRemark(loc) << *this << " - TestC";
- }
-};
-
/// Storage for simple named recursive types, where the type is identified by
/// its name and can "contain" another type, including itself.
struct TestRecursiveTypeStorage : public TypeStorage {
StringRef getName() { return getImpl()->name; }
};
-struct TestTypeWithLayoutStorage : public TypeStorage {
- using KeyTy = unsigned;
-
- explicit TestTypeWithLayoutStorage(unsigned key) : key(key) {}
- bool operator==(const KeyTy &other) const { return other == key; }
-
- static TestTypeWithLayoutStorage *construct(TypeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<TestTypeWithLayoutStorage>())
- TestTypeWithLayoutStorage(key);
- }
-
- unsigned key;
-};
-
-class TestTypeWithLayout
- : public Type::TypeBase<TestTypeWithLayout, Type, TestTypeWithLayoutStorage,
- DataLayoutTypeInterface::Trait> {
-public:
- using Base::Base;
-
- static TestTypeWithLayout get(MLIRContext *ctx, unsigned key) {
- return Base::get(ctx, key);
- }
-
- unsigned getKey() { return getImpl()->key; }
-
- unsigned getTypeSizeInBits(const DataLayout &dataLayout,
- DataLayoutEntryListRef params) const {
- return extractKind(params, "size");
- }
-
- unsigned getABIAlignment(const DataLayout &dataLayout,
- DataLayoutEntryListRef params) const {
- return extractKind(params, "alignment");
- }
-
- unsigned getPreferredAlignment(const DataLayout &dataLayout,
- DataLayoutEntryListRef params) const {
- return extractKind(params, "preferred");
- }
-
- bool areCompatible(DataLayoutEntryListRef oldLayout,
- DataLayoutEntryListRef newLayout) const {
- unsigned old = extractKind(oldLayout, "alignment");
- return old == 1 || extractKind(newLayout, "alignment") <= old;
- }
-
- LogicalResult verifyEntries(DataLayoutEntryListRef params,
- Location loc) const;
-
-private:
- unsigned extractKind(DataLayoutEntryListRef params,
- StringRef expectedKind) const;
-};
-
} // namespace test
} // namespace mlir
// DEF: CompoundAAttrStorage (
// DEF-NEXT: : ::mlir::AttributeStorage(inner),
-// DEF: bool operator==(const KeyTy &key) const {
-// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key)))
+// DEF: bool operator==(const KeyTy &tblgenKey) const {
+// DEF-NEXT: if (!(widthOfSomething == std::get<0>(tblgenKey)))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(exampleTdType == std::get<1>(key)))
+// DEF-NEXT: if (!(exampleTdType == std::get<1>(tblgenKey)))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key))))
+// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(dims == std::get<3>(key)))
+// DEF-NEXT: if (!(dims == std::get<3>(tblgenKey)))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(getType() == std::get<4>(key)))
+// DEF-NEXT: if (!(getType() == std::get<4>(tblgenKey)))
// DEF-NEXT: return false;
// DEF-NEXT: return true;
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
/// {1}: The name of the type base class.
/// {2}: The name of the base value type, e.g. Attribute or Type.
/// {3}: The tablegen record type prefix, e.g. Attr or Type.
+/// {4}: The traits of the def class.
static const char *const defDeclSingletonBeginStr = R"(
- class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage> {{
+ class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage{4}> {{
public:
/// Inherit some necessary constructors from '{3}Base'.
using Base::Base;
)";
-/// The code block for the start of a typeDef class declaration -- parametric
-/// case.
+/// The code block for the start of a class declaration -- parametric case.
///
-/// {0}: The name of the typeDef class.
-/// {1}: The name of the type base class.
-/// {2}: The typeDef storage class namespace.
+/// {0}: The name of the def class.
+/// {1}: The name of the base class.
+/// {2}: The def storage class namespace.
/// {3}: The storage class name.
/// {4}: The name of the base value type, e.g. Attribute or Type.
/// {5}: The tablegen record type prefix, e.g. Attr or Type.
+/// {6}: The traits of the def class.
static const char *const defDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
} // end namespace {2}
class {0} : public ::mlir::{4}::{5}Base<{0}, {1},
- {2}::{3}> {{
+ {2}::{3}{6}> {{
public:
/// Inherit some necessary constructors from '{5}Base'.
using Base::Base;
}
}
+static void emitInterfaceMethodDecls(const InterfaceTrait *trait,
+ raw_ostream &os) {
+ Interface interface = trait->getInterface();
+
+ // Get the set of methods that should always be declared.
+ auto alwaysDeclaredMethodsVec = trait->getAlwaysDeclaredMethods();
+ llvm::StringSet<> alwaysDeclaredMethods;
+ alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
+ alwaysDeclaredMethodsVec.end());
+
+ for (const InterfaceMethod &method : interface.getMethods()) {
+ // Don't declare if the method has a body.
+ if (method.getBody())
+ continue;
+ // Don't declare if the method has a default implementation and the def
+ // didn't request that it always be declared.
+ if (method.getDefaultImplementation() &&
+ !alwaysDeclaredMethods.count(method.getName()))
+ continue;
+
+ // Emit the method declaration.
+ os << " " << (method.isStatic() ? "static " : "")
+ << method.getReturnType() << " " << method.getName() << "(";
+ llvm::interleaveComma(method.getArguments(), os,
+ [&](const InterfaceMethod::Argument &arg) {
+ os << arg.type << " " << arg.name;
+ });
+ os << ")" << (method.isStatic() ? "" : " const") << ";\n";
+ }
+}
+
void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
SmallVector<AttrOrTypeParameter, 4> params;
def.getParameters(params);
+ // Build the trait list for this def.
+ std::vector<std::string> traitList;
+ StringSet<> traitSet;
+ for (const Trait &baseTrait : def.getTraits()) {
+ std::string traitStr;
+ if (const auto *trait = dyn_cast<NativeTrait>(&baseTrait))
+ traitStr = trait->getFullyQualifiedTraitName();
+ else if (const auto *trait = dyn_cast<InterfaceTrait>(&baseTrait))
+ traitStr = trait->getFullyQualifiedTraitName();
+ else
+ llvm_unreachable("unexpected Attribute/Type trait type");
+
+ if (traitSet.insert(traitStr).second)
+ traitList.emplace_back(std::move(traitStr));
+ }
+ std::string traitStr;
+ if (!traitList.empty())
+ traitStr = ", " + llvm::join(traitList, ", ");
+
// Emit the beginning string template: either the singleton or parametric
// template.
if (def.getNumParameters() == 0) {
os << formatv(defDeclSingletonBeginStr, def.getCppClassName(),
- def.getCppBaseClassName(), valueType, defTypePrefix);
+ def.getCppBaseClassName(), valueType, defTypePrefix,
+ traitStr);
} else {
os << formatv(defDeclParametricBeginStr, def.getCppClassName(),
def.getCppBaseClassName(), def.getStorageNamespace(),
- def.getStorageClassName(), valueType, defTypePrefix);
+ def.getStorageClassName(), valueType, defTypePrefix,
+ traitStr);
}
// Emit the extra declarations first in case there's a definition in there.
}
}
+ // Emit any interface method declarations.
+ for (const Trait &trait : def.getTraits()) {
+ if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) {
+ if (traitDef->shouldDeclareMethods())
+ emitInterfaceMethodDecls(traitDef, os);
+ }
+ }
+
// End the decl.
os << " };\n";
}
/// Define a construction method for creating a new instance of this
/// storage.
static {0} *construct(::mlir::{1}StorageAllocator &allocator,
- const KeyTy &key) {{
+ const KeyTy &tblgenKey) {{
)";
/// The storage class' constructor return template.
paramInitializer, parameterTypeList, valueType);
// * Emit the comparison method.
- os << " bool operator==(const KeyTy &key) const {\n";
+ os << " bool operator==(const KeyTy &tblgenKey) const {\n";
for (auto it : llvm::enumerate(params)) {
os << " if (!(";
bool isSelfType = isa<AttributeSelfTypeParameter>(it.value());
FmtContext context;
context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName())
- .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)");
+ .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(tblgenKey)");
// Use the parameter specified comparator if possible, otherwise default to
// operator==.
os << " return true;\n }\n";
// * Emit the haskKey method.
- os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
+ os << " static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {\n";
// Extract each parameter from the key.
os << " return ::llvm::hash_combine(";
llvm::interleaveComma(
llvm::seq<unsigned>(0, params.size()), os,
- [&](unsigned it) { os << "std::get<" << it << ">(key)"; });
+ [&](unsigned it) { os << "std::get<" << it << ">(tblgenKey)"; });
os << ");\n }\n";
// * Emit the construct method.
// here and then they can write the definition elsewhere.
if (def.hasStorageCustomConstructor()) {
os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator "
- "&allocator, const KeyTy &key);\n",
+ "&allocator, const KeyTy &tblgenKey);\n",
def.getStorageClassName(), valueType);
// Otherwise, generate one.
os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
valueType);
for (unsigned i = 0, e = params.size(); i < e; ++i) {
- os << formatv(" auto {0} = std::get<{1}>(key);\n",
+ os << formatv(" auto {0} = std::get<{1}>(tblgenKey);\n",
params[i].getName(), i);
}
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
+#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Path.h"
void genOpInterfaceMethods();
// Generate op interface methods for the given interface.
- void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
+ void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
// Generate op interface method for the given interface method. If
// 'declaration' is true, generates a declaration, else a definition.
}
}
-void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
- auto interface = opTrait->getOpInterface();
+void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
+ Interface interface = opTrait->getInterface();
// Get the set of methods that should always be declared.
auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
- if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+ if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
if (opTrait->shouldDeclareMethods())
genOpInterfaceMethods(opTrait);
}
return;
// Generate 'inferReturnTypes' method declaration using the interface method
// declared in 'InferTypeOpInterface' op interface.
- const auto *trait = dyn_cast<InterfaceOpTrait>(
+ const auto *trait = dyn_cast<InterfaceTrait>(
op.getTrait("::mlir::InferTypeOpInterface::Trait"));
- auto interface = trait->getOpInterface();
+ Interface interface = trait->getInterface();
OpMethod *method = [&]() -> OpMethod * {
for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
if (interfaceMethod.getName() == "inferReturnTypes") {
genOperandResultVerifier(body, op.getResults(), "result");
for (auto &trait : op.getTraits()) {
- if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
+ if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
body << tgfmt(" if (!($0))\n "
"return emitOpError(\"failed to verify that $1\");\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
// Add the native and interface traits.
for (const auto &trait : op.getTraits()) {
- if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
- opClass.addTrait(opTrait->getTrait());
- else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
- opClass.addTrait(opTrait->getTrait());
+ if (auto opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
+ opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+ else if (auto opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
+ opClass.addTrait(opTrait->getFullyQualifiedTraitName());
}
}
// Verify a few traits first so that we can use
// getODSOperands()/getODSResults() in the rest of the verifier.
for (auto &trait : op.getTraits()) {
- if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
- if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
+ if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
+ if (t->getFullyQualifiedTraitName() ==
+ "::mlir::OpTrait::AttrSizedOperandSegments") {
body << formatv(checkAttrSizedValueSegmentsCode,
"operand_segment_sizes", op.getNumOperands(),
"operand");
- } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
+ } else if (t->getFullyQualifiedTraitName() ==
+ "::mlir::OpTrait::AttrSizedResultSegments") {
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
op.getNumResults(), "result");
}
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
- hasImplicitTermTrait =
- llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
- return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
- });
+ hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
+ return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
+ });
hasSingleBlockTrait =
- hasImplicitTermTrait ||
- llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
- if (auto *native = dyn_cast<NativeOpTrait>(&trait))
- return native->getTrait() == "::mlir::OpTrait::SingleBlock";
- return false;
- });
+ hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
}
/// Generate the operation parser from this format.
// Check for any type traits that we can use for inferring types.
llvm::StringMap<TypeResolutionInstance> variableTyResolver;
- for (const OpTrait &trait : op.getTraits()) {
+ for (const Trait &trait : op.getTraits()) {
const llvm::Record &def = trait.getDef();
if (def.isSubClassOf("AllTypesMatch")) {
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),