[MLIR] Add C API for navigating up the IR tree
authorGeorge <GeorgeLyon@users.noreply.github.com>
Tue, 9 Feb 2021 03:54:19 +0000 (19:54 -0800)
committerGeorge <GeorgeLyon@users.noreply.github.com>
Tue, 9 Feb 2021 03:54:38 +0000 (19:54 -0800)
Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96301

mlir/include/mlir-c/IR.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c

index 8bee618..65c097a 100644 (file)
@@ -322,6 +322,9 @@ static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
 MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op,
                                            MlirOperation other);
 
+/// Gets the context this operation is associated with
+MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
+
 /// Gets the name of the operation as an identifier.
 MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op);
 
@@ -467,6 +470,9 @@ static inline bool mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
 /// perform deep comparison.
 MLIR_CAPI_EXPORTED bool mlirBlockEqual(MlirBlock block, MlirBlock other);
 
+/// Returns the closest surrounding operation that contains this block.
+MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock);
+
 /// Returns the block immediately following the given block in its parent
 /// region.
 MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
index fdb830e..87c0994 100644 (file)
@@ -305,6 +305,10 @@ bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
   return unwrap(op) == unwrap(other);
 }
 
+MlirContext mlirOperationGetContext(MlirOperation op) {
+  return wrap(unwrap(op)->getContext());
+}
+
 MlirIdentifier mlirOperationGetName(MlirOperation op) {
   return wrap(unwrap(op)->getName().getIdentifier());
 }
@@ -461,6 +465,10 @@ bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
   return unwrap(block) == unwrap(other);
 }
 
+MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
+  return wrap(unwrap(block)->getParentOp());
+}
+
 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
   return wrap(unwrap(block)->getNextNode());
 }
index c2b13d5..2f81d13 100644 (file)
@@ -1439,6 +1439,37 @@ int registerOnlyStd() {
   return 0;
 }
 
+/// Tests backreference APIs
+static int testBackreferences() {
+  fprintf(stderr, "@test_backreferences\n");
+
+  MlirContext ctx = mlirContextCreate();
+  mlirContextSetAllowUnregisteredDialects(ctx, true);
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+
+  MlirOperationState opState = mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc);
+  MlirRegion region = mlirRegionCreate();
+  MlirBlock block = mlirBlockCreate(0, NULL);
+  mlirRegionAppendOwnedBlock(region, block);
+  mlirOperationStateAddOwnedRegions(&opState, 1, &region);
+  MlirOperation op = mlirOperationCreate(&opState);
+
+  if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
+    fprintf(stderr, "ERROR: Getting context from operation failed\n");
+    return 1;
+  }
+  if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) {
+    fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
+    return 2;
+  }
+  
+  mlirOperationDestroy(op);
+  mlirContextDestroy(ctx);
+
+  // CHECK-LABEL: @test_backreferences
+  return 0;
+}
+
 // Wraps a diagnostic into additional text we can match against.
 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
   fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData);
@@ -1514,6 +1545,8 @@ int main() {
     return 8;
   if (registerOnlyStd())
     return 9;
+  if (testBackreferences())
+    return 10;
 
   mlirContextDestroy(ctx);