[mlir][irdl] Add verification of IRDL ops
authorMathieu Fehr <mathieu.fehr@gmail.com>
Wed, 8 Mar 2023 22:16:02 +0000 (23:16 +0100)
committerMathieu Fehr <mathieu.fehr@gmail.com>
Wed, 17 May 2023 12:34:00 +0000 (13:34 +0100)
This patch adds verification on registered IRDL operations, types,
and attributes.

This is done through an interface implemented by operations from the
`irdl` dialect, which translate the operations into `Constraint`.
This interface is then use in the `registerDialect` function to
generate verifiers for the entire operation/type/attribute.

Depends on D145733

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D145734

mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
mlir/include/mlir/Dialect/IRDL/IR/IRDL.h
mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h [new file with mode: 0644]
mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td [new file with mode: 0644]
mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
mlir/lib/Dialect/IRDL/CMakeLists.txt
mlir/lib/Dialect/IRDL/IR/IRDL.cpp
mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp [new file with mode: 0644]
mlir/lib/Dialect/IRDL/IRDLLoading.cpp
mlir/test/Dialect/IRDL/testd.mlir

index e165bd7..1a40883 100644 (file)
@@ -1,5 +1,12 @@
 add_mlir_dialect(IRDL irdl)
 
+# Add IRDL interfaces
+set(LLVM_TARGET_DEFINITIONS IRDLInterfaces.td)
+mlir_tablegen(IRDLInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(IRDLInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRIRDLInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRIRDLInterfacesIncGen)
+
 # Add IRDL operations
 set(LLVM_TARGET_DEFINITIONS IRDLOps.td)
 mlir_tablegen(IRDLOps.h.inc -gen-op-decls)
index c22f5e2..1b32691 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_IRDL_IR_IRDL_H_
 #define MLIR_DIALECT_IRDL_IR_IRDL_H_
 
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
 #include "mlir/Dialect/IRDL/IR/IRDLTraits.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h
new file mode 100644 (file)
index 0000000..6455385
--- /dev/null
@@ -0,0 +1,38 @@
+//===- IRDLInterfaces.h - IRDL interfaces definition ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces used by the IRDL dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
+#define MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
+
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+#include <optional>
+
+namespace mlir {
+namespace irdl {
+class TypeOp;
+class AttributeOp;
+} // namespace irdl
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// IRDL Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h.inc"
+
+#endif //  MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
new file mode 100644 (file)
index 0000000..0e45711
--- /dev/null
@@ -0,0 +1,40 @@
+//===- IRDLInterfaces.td - IRDL Attributes -----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces used by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
+#define MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def VerifyConstraintInterface : OpInterface<"VerifyConstraintInterface"> {
+  let cppNamespace = "::mlir::irdl";
+
+  let description = [{
+    Interface to get an IRDL constraint verifier from an operation. 
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{
+        Get an instance of a constraint verifier for the associated operation."
+        Returns `nullptr` upon failure.
+      }],
+      "std::unique_ptr<::mlir::irdl::Constraint>",
+      "getVerifier",
+      (ins "::mlir::SmallVector<Value> const&":$valueRes,
+      "::mlir::DenseMap<::mlir::irdl::TypeOp, std::unique_ptr<::mlir::DynamicTypeDefinition>> &":$types,
+      "::mlir::DenseMap<::mlir::irdl::AttributeOp, std::unique_ptr<::mlir::DynamicAttrDefinition>> &":$attrs)
+    >
+  ];
+}
+
+#endif // MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
index 59d1524..5cce685 100644 (file)
@@ -15,6 +15,7 @@
 
 include "IRDL.td"
 include "IRDLTypes.td"
+include "IRDLInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -264,7 +265,8 @@ def IRDL_ResultsOp : IRDL_Op<"results", [HasParent<"OperationOp">]> {
 //===----------------------------------------------------------------------===//
 
 class IRDL_ConstraintOp<string mnemonic, list<Trait> traits = []>
-    : IRDL_Op<mnemonic, traits> {
+    : IRDL_Op<mnemonic, [VerifyConstraintInterface,
+        DeclareOpInterfaceMethods<VerifyConstraintInterface>] # traits> {
 }
 
 def IRDL_Is : IRDL_ConstraintOp<"is",
index 7af0e42..d25760e 100644 (file)
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
+  IR/IRDLOps.cpp
   IRDLLoading.cpp
   IRDLVerifiers.cpp
 
index e2649f2..01e58cc 100644 (file)
@@ -71,6 +71,8 @@ LogicalResult DialectOp::verify() {
   return success();
 }
 
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
 
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
new file mode 100644 (file)
index 0000000..a9956cc
--- /dev/null
@@ -0,0 +1,61 @@
+//===- IRDLOps.cpp - IRDL dialect -------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+std::unique_ptr<Constraint> Is::getVerifier(
+    SmallVector<Value> const &valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  return std::make_unique<IsConstraint>(getExpectedAttr());
+}
+
+std::unique_ptr<Constraint> Parametric::getVerifier(
+    SmallVector<Value> const &valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  SmallVector<unsigned> constraints;
+  for (Value arg : getArgs()) {
+    for (auto [i, value] : enumerate(valueToConstr)) {
+      if (value == arg) {
+        constraints.push_back(i);
+        break;
+      }
+    }
+  }
+
+  // Symbol reference case for the base
+  SymbolRefAttr symRef = getBaseType();
+  Operation *defOp =
+      SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+  if (!defOp) {
+    emitError() << symRef << " does not refer to any existing symbol";
+    return nullptr;
+  }
+
+  if (auto typeOp = dyn_cast<TypeOp>(defOp))
+    return std::make_unique<DynParametricTypeConstraint>(types[typeOp].get(),
+                                                         constraints);
+
+  if (auto attrOp = dyn_cast<AttributeOp>(defOp))
+    return std::make_unique<DynParametricAttrConstraint>(attrs[attrOp].get(),
+                                                         constraints);
+
+  llvm_unreachable("verifier should ensure that the referenced operation is "
+                   "either a type or an attribute definition");
+}
+
+std::unique_ptr<Constraint> Any::getVerifier(
+    SmallVector<Value> const &valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  return std::make_unique<AnyAttributeConstraint>();
+}
index fb00085..f65d0ec 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/Support/LogicalResult.h"
 using namespace mlir;
 using namespace mlir::irdl;
 
+/// Verify that the given list of parameters satisfy the given constraints.
+/// This encodes the logic of the verification method for attributes and types
+/// defined with IRDL.
+static LogicalResult
+irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError,
+                       ArrayRef<Attribute> params,
+                       ArrayRef<std::unique_ptr<Constraint>> constraints,
+                       ArrayRef<size_t> paramConstraints) {
+  if (params.size() != paramConstraints.size()) {
+    emitError() << "expected " << paramConstraints.size()
+                << " type arguments, but had " << params.size();
+    return failure();
+  }
+
+  ConstraintVerifier verifier(constraints);
+
+  // Check that each parameter satisfies its constraint.
+  for (auto [i, param] : enumerate(params))
+    if (failed(verifier.verify(emitError, param, paramConstraints[i])))
+      return failure();
+
+  return success();
+}
+
+/// Verify that the given operation satisfies the given constraints.
+/// This encodes the logic of the verification method for operations defined
+/// with IRDL.
+static LogicalResult
+irdlOpVerifier(Operation *op, ArrayRef<std::unique_ptr<Constraint>> constraints,
+               ArrayRef<size_t> operandConstrs,
+               ArrayRef<size_t> resultConstrs) {
+  /// Check that we have the right number of operands.
+  unsigned numOperands = op->getNumOperands();
+  size_t numExpectedOperands = operandConstrs.size();
+  if (numOperands != numExpectedOperands)
+    return op->emitOpError() << numExpectedOperands
+                             << " operands expected, but got " << numOperands;
+
+  /// Check that we have the right number of results.
+  unsigned numResults = op->getNumResults();
+  size_t numExpectedResults = resultConstrs.size();
+  if (numResults != numExpectedResults)
+    return op->emitOpError()
+           << numExpectedResults << " results expected, but got " << numResults;
+
+  auto emitError = [op]() { return op->emitError(); };
+
+  ConstraintVerifier verifier(constraints);
+
+  /// Check that all operands satisfy the constraints.
+  for (auto [i, operandType] : enumerate(op->getOperandTypes()))
+    if (failed(verifier.verify({emitError}, TypeAttr::get(operandType),
+                               operandConstrs[i])))
+      return failure();
+
+  /// Check that all results satisfy the constraints.
+  for (auto [i, resultType] : enumerate(op->getResultTypes()))
+    if (failed(verifier.verify({emitError}, TypeAttr::get(resultType),
+                               resultConstrs[i])))
+      return failure();
+
+  return success();
+}
+
 /// Define and load an operation represented by a `irdl.operation`
 /// operation.
-static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
+static WalkResult loadOperation(
+    OperationOp op, ExtensibleDialect *dialect,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  // Resolve SSA values to verifier constraint slots
+  SmallVector<Value> constrToValue;
+  for (Operation &op : op->getRegion(0).getOps()) {
+    if (isa<VerifyConstraintInterface>(op)) {
+      if (op.getNumResults() != 1)
+        return op.emitError()
+               << "IRDL constraint operations must have exactly one result";
+      constrToValue.push_back(op.getResult(0));
+    }
+  }
+
+  // Build the verifiers for each constraint slot
+  SmallVector<std::unique_ptr<Constraint>> constraints;
+  for (Value v : constrToValue) {
+    VerifyConstraintInterface op =
+        cast<VerifyConstraintInterface>(v.getDefiningOp());
+    std::unique_ptr<Constraint> verifier =
+        op.getVerifier(constrToValue, types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    constraints.push_back(std::move(verifier));
+  }
+
+  SmallVector<size_t> operandConstraints;
+  SmallVector<size_t> resultConstraints;
+
+  // Gather which constraint slots correspond to operand constraints
+  auto operandsOp = op.getOp<OperandsOp>();
+  if (operandsOp.has_value()) {
+    operandConstraints.reserve(operandsOp->getArgs().size());
+    for (Value operand : operandsOp->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == operand) {
+          operandConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
+  // Gather which constraint slots correspond to result constraints
+  auto resultsOp = op.getOp<ResultsOp>();
+  if (resultsOp.has_value()) {
+    resultConstraints.reserve(resultsOp->getArgs().size());
+    for (Value result : resultsOp->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == result) {
+          resultConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
   // IRDL does not support defining custom parsers or printers.
   auto parser = [](OpAsmParser &parser, OperationState &result) {
     return failure();
@@ -33,7 +155,13 @@ static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
     printer.printGenericOp(op);
   };
 
-  auto verifier = [](Operation *op) { return success(); };
+  auto verifier =
+      [constraints{std::move(constraints)},
+       operandConstraints{std::move(operandConstraints)},
+       resultConstraints{std::move(resultConstraints)}](Operation *op) {
+        return irdlOpVerifier(op, constraints, operandConstraints,
+                              resultConstraints);
+      };
 
   // IRDL does not support defining regions.
   auto regionVerifier = [](Operation *op) { return success(); };
@@ -46,6 +174,71 @@ static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
   return WalkResult::advance();
 }
 
+/// Get the verifier of a type or attribute definition.
+/// Return nullptr if the definition is invalid.
+static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
+    Operation *attrOrTypeDef, ExtensibleDialect *dialect,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
+         "Expected an attribute or type definition");
+
+  // Resolve SSA values to verifier constraint slots
+  SmallVector<Value> constrToValue;
+  for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) {
+    if (isa<VerifyConstraintInterface>(op)) {
+      assert(op.getNumResults() == 1 &&
+             "IRDL constraint operations must have exactly one result");
+      constrToValue.push_back(op.getResult(0));
+    }
+  }
+
+  // Build the verifiers for each constraint slot
+  SmallVector<std::unique_ptr<Constraint>> constraints;
+  for (Value v : constrToValue) {
+    VerifyConstraintInterface op =
+        cast<VerifyConstraintInterface>(v.getDefiningOp());
+    std::unique_ptr<Constraint> verifier =
+        op.getVerifier(constrToValue, types, attrs);
+    if (!verifier)
+      return {};
+    constraints.push_back(std::move(verifier));
+  }
+
+  // Get the parameter definitions.
+  std::optional<ParametersOp> params;
+  if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
+    params = attr.getOp<ParametersOp>();
+  else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
+    params = type.getOp<ParametersOp>();
+
+  // Gather which constraint slots correspond to parameter constraints
+  SmallVector<size_t> paramConstraints;
+  if (params.has_value()) {
+    paramConstraints.reserve(params->getArgs().size());
+    for (Value param : params->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == param) {
+          paramConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
+  auto verifier = [paramConstraints{std::move(paramConstraints)},
+                   constraints{std::move(constraints)}](
+                      function_ref<InFlightDiagnostic()> emitError,
+                      ArrayRef<Attribute> params) {
+    return irdlAttrOrTypeVerifier(emitError, params, constraints,
+                                  paramConstraints);
+  };
+
+  // While the `std::move` is not required, not adding it triggers a bug in
+  // clang-10.
+  return std::move(verifier);
+}
+
 /// Load all dialects in the given module, without loading any operation, type
 /// or attribute definitions.
 static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
@@ -108,9 +301,33 @@ LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
   DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs =
       preallocateAttrDefs(op, dialects);
 
+  // Set the verifier for types.
+  WalkResult res = op.walk([&](TypeOp typeOp) {
+    DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+        typeOp, dialects[typeOp.getParentOp()], types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    types[typeOp]->setVerifyFn(std::move(verifier));
+    return WalkResult::advance();
+  });
+  if (res.wasInterrupted())
+    return failure();
+
+  // Set the verifier for attributes.
+  res = op.walk([&](AttributeOp attrOp) {
+    DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+        attrOp, dialects[attrOp.getParentOp()], types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    attrs[attrOp]->setVerifyFn(std::move(verifier));
+    return WalkResult::advance();
+  });
+  if (res.wasInterrupted())
+    return failure();
+
   // Define and load all operations.
-  WalkResult res = op.walk([&](OperationOp opOp) {
-    return loadOperation(opOp, dialects[opOp.getParentOp()]);
+  res = op.walk([&](OperationOp opOp) {
+    return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
   });
   if (res.wasInterrupted())
     return failure();
index f6d1bcb..e9be54b 100644 (file)
@@ -45,6 +45,13 @@ func.func @succeededEqConstraint() {
   return
 }
 
+// -----
+
+func.func @failedEqConstraint() {
+  // expected-error@+1 {{expected 'i32' but got 'i64'}}
+  "testd.eq"() : () -> i64
+  return
+}
 
 // -----
 
@@ -74,6 +81,13 @@ func.func @succeededDynBaseConstraint() {
   return
 }
 
+// -----
+
+func.func @failedDynBaseConstraint() {
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
+  "testd.dynbase"() : () -> i32
+  return
+}
 
 // -----
 
@@ -89,6 +103,22 @@ func.func @succeededDynParamsConstraint() {
 
 // -----
 
+func.func @failedDynParamsConstraintBase() {
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
+  "testd.dynparams"() : () -> i32
+  return
+}
+
+// -----
+
+func.func @failedDynParamsConstraintParam() {
+  // expected-error@+1 {{expected 'i32' but got 'i1'}}
+  "testd.dynparams"() : () -> !testd.parametric<i1>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Constraint variables
 //===----------------------------------------------------------------------===//
@@ -106,3 +136,11 @@ func.func @succeededConstraintVars2() {
   "testd.constraint_vars"() : () -> (i64, i64)
   return
 }
+
+// -----
+
+func.func @failedConstraintVars() {
+  // expected-error@+1 {{expected 'i64' but got 'i32'}}
+  "testd.constraint_vars"() : () -> (i64, i32)
+  return
+}