def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- TypesMatchWith<
- "result and attribute have the same type",
- "value", "result", "$_self">]> {
+ AllTypesMatch<["value", "result"]>]> {
let summary = "integer or floating point constant";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
let builders = [
- OpBuilder<(ins "Attribute":$value),
- [{ build($_builder, $_state, value.getType(), value); }]>,
OpBuilder<(ins "Attribute":$value, "Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];
/// ensure that the static functions have a unique name.
std::string uniqueOutputLabel;
- /// Unique constraints by their predicate and summary. Constraints that share
- /// the same predicate may have different descriptions; ensure that the
- /// correct error message is reported when verification fails.
- struct ConstraintUniquer {
- static Constraint getEmptyKey();
- static Constraint getTombstoneKey();
- static unsigned getHashValue(Constraint constraint);
- static bool isEqual(Constraint lhs, Constraint rhs);
- };
/// Use a MapVector to ensure that functions are generated deterministically.
- using ConstraintMap =
- llvm::MapVector<Constraint, std::string,
- llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
+ using ConstraintMap = llvm::MapVector<Constraint, std::string,
+ llvm::DenseMap<Constraint, unsigned>>;
/// A generic function to emit constraints
void emitConstraints(const ConstraintMap &constraints, StringRef selfName,
} // namespace tblgen
} // namespace mlir
+namespace llvm {
+/// Unique constraints by their predicate and summary. Constraints that share
+/// the same predicate may have different descriptions; ensure that the
+/// correct error message is reported when verification fails.
+template <>
+struct DenseMapInfo<mlir::tblgen::Constraint> {
+ using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
+
+ static mlir::tblgen::Constraint getEmptyKey();
+ static mlir::tblgen::Constraint getTombstoneKey();
+ static unsigned getHashValue(mlir::tblgen::Constraint constraint);
+ static bool isEqual(mlir::tblgen::Constraint lhs,
+ mlir::tblgen::Constraint rhs);
+};
+} // namespace llvm
+
#endif // MLIR_TABLEGEN_CONSTRAINT_H_
std::vector<std::string> &&entities)
: constraint(constraint), self(std::string(self)),
entities(std::move(entities)) {}
+
+Constraint DenseMapInfo<Constraint>::getEmptyKey() {
+ return Constraint(RecordDenseMapInfo::getEmptyKey(),
+ Constraint::CK_Uncategorized);
+}
+
+Constraint DenseMapInfo<Constraint>::getTombstoneKey() {
+ return Constraint(RecordDenseMapInfo::getTombstoneKey(),
+ Constraint::CK_Uncategorized);
+}
+
+unsigned DenseMapInfo<Constraint>::getHashValue(Constraint constraint) {
+ if (constraint == getEmptyKey())
+ return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
+ if (constraint == getTombstoneKey()) {
+ return RecordDenseMapInfo::getHashValue(
+ RecordDenseMapInfo::getTombstoneKey());
+ }
+ return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
+}
+
+bool DenseMapInfo<Constraint>::isEqual(Constraint lhs, Constraint rhs) {
+ if (lhs == rhs)
+ return true;
+ if (lhs == getEmptyKey() || lhs == getTombstoneKey())
+ return false;
+ if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+ return false;
+ return lhs.getPredicate() == rhs.getPredicate() &&
+ lhs.getSummary() == rhs.getSummary();
+}
continue;
}
- if (getArg(*mi).is<NamedAttribute *>()) {
- // TODO: Handle attributes.
- continue;
- }
resultTypeMapping[i].emplace_back(*mi);
found = true;
}
loc=None,
ip=None):
if isinstance(value, int):
- super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
+ super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
- super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
+ super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
- super().__init__(result, value, loc=loc, ip=ip)
+ super().__init__(value, loc=loc, ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
// -----
func.func @complex_constant_wrong_attribute_type() {
- // expected-error @+1 {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ // expected-error @+1 {{'arith.constant' op failed to verify that all of {value, result} have same type}}
%0 = "arith.constant" () {value = 1.0 : f32} : () -> complex<f32>
return
}
func.func @constant() {
^bb:
- %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
func.func @constant_out_of_range() {
^bb:
- %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
func.func @constant_wrong_type() {
^bb:
- %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
// Emit the first available call stack in the fused location.
func.func @constant_out_of_range() {
- // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that result and attribute have the same type
+ // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that all of {value, result} have same type
// CHECK-NEXT: mysource2:1:0: note: called from
// CHECK-NEXT: mysource3:2:0: note: called from
%x = "arith.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))])
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
// CHECK-NOT: }
-// CHECK: inferredReturnTypes[0] = operands[0].getType();
+// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
def OpL2 : NS_Op<"op_with_all_types_constraint",
[AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> {
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
// CHECK-NOT: }
-// CHECK: inferredReturnTypes[0] = operands[2].getType();
-// CHECK: inferredReturnTypes[1] = operands[0].getType();
+// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
+// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
+// CHECK: inferredReturnTypes[1] = odsInferredType1;
+
+def OpL3 : NS_Op<"op_with_all_types_constraint",
+ [AllTypesMatch<["a", "b"]>]> {
+ let arguments = (ins I32Attr:$a);
+ let results = (outs AnyType:$b);
+}
+
+// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
+// CHECK-NOT: }
+// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
//===----------------------------------------------------------------------===//
// Constraint Uniquing
-using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
-
-Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
- return Constraint(RecordDenseMapInfo::getEmptyKey(),
- Constraint::CK_Uncategorized);
-}
-
-Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
- return Constraint(RecordDenseMapInfo::getTombstoneKey(),
- Constraint::CK_Uncategorized);
-}
-
-unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
- Constraint constraint) {
- if (constraint == getEmptyKey())
- return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
- if (constraint == getTombstoneKey()) {
- return RecordDenseMapInfo::getHashValue(
- RecordDenseMapInfo::getTombstoneKey());
- }
- return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
-}
-
-bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
- Constraint rhs) {
- if (lhs == rhs)
- return true;
- if (lhs == getEmptyKey() || lhs == getTombstoneKey())
- return false;
- if (rhs == getEmptyKey() || rhs == getTombstoneKey())
- return false;
- return lhs.getPredicate() == rhs.getPredicate() &&
- lhs.getSummary() == rhs.getSummary();
-}
-
/// An attribute constraint that references anything other than itself and the
/// current op cannot be generically extracted into a function. Most
/// prohibitive are operands and results, which require calls to
fctx.withBuilder("odsBuilder");
body << " ::mlir::Builder odsBuilder(context);\n";
- auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
- if (!type.isArg())
- return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
- auto argIndex = type.getArg();
- assert(!op.getArg(argIndex).is<NamedAttribute *>());
+ // Preprocess the result types and build all of the types used during
+ // inferrence. This limits the amount of duplicated work when a type is used
+ // to infer multiple others.
+ llvm::DenseMap<Constraint, int> constraintsTypes;
+ llvm::DenseMap<int, int> argumentsTypes;
+ int inferredTypeIdx = 0;
+ for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+ auto type = op.getSameTypeAsResult(i).front();
+
+ // If the type isn't an argument, it refers to a buildable type.
+ if (!type.isArg()) {
+ auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
+ if (!it.second)
+ continue;
+
+ // If we haven't seen this constraint, generate a variable for it.
+ body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
+ << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
+ continue;
+ }
+
+ // Otherwise, this is an argument.
+ int argIndex = type.getArg();
+ auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
+ if (!it.second)
+ continue;
+ body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
+
+ // If this is an operand, just index into operand list to access the type.
auto arg = op.getArgToOperandOrAttribute(argIndex);
- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
- return body << "operands[" << arg.operandOrAttributeIndex()
- << "].getType()";
- return body << "attributes[" << arg.operandOrAttributeIndex()
- << "].getType()";
- };
+ if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
+ body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
+
+ // If this is an attribute, index into the attribute dictionary.
+ } else {
+ auto *attr =
+ op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
+ body << "attributes.get(\"" << attr->name << "\").getType()";
+ }
+ body << ";\n";
+ }
+ // Perform a second pass that handles assigning the inferred types to the
+ // results.
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
- body << " inferredReturnTypes[" << i << "] = ";
auto types = op.getSameTypeAsResult(i);
- emitType(types[0]) << ";\n";
+
+ // Append the inferred type.
+ auto type = types.front();
+ body << " inferredReturnTypes[" << i << "] = odsInferredType"
+ << (type.isArg() ? argumentsTypes[type.getArg()]
+ : constraintsTypes[type.getType()])
+ << ";\n";
+
if (types.size() == 1)
continue;
// TODO: We could verify equality here, but skipping that for verification.