From 853a614864754cd4b000f03a7ab8fbba103d6177 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Mon, 14 Jun 2021 17:28:01 -0700 Subject: [PATCH] [mlir:OpFormatGen] Add Support for `$_ctxt` in the transformer. This is useful for "build tuple" type ops. In my case, in npcomp, I have an op: ``` // Result type is `!torch.tuple`. torch.prim.TupleConstruct %0, %1 : !torch.tensor, !torch.tensor ``` and the context is required for the `Torch::TupleType::get` call (for the case of an empty tuple). The handling of these FmtContext's in the code is pretty ad-hoc -- I didn't attempt to rationalize it and just made a targeted fix. As someone unfamiliar with the code I had a hard time seeing how to more broadly fix the situation. Differential Revision: https://reviews.llvm.org/D104274 --- mlir/test/lib/Dialect/Test/TestOps.td | 9 +++++++++ mlir/test/mlir-tblgen/op-format.mlir | 3 +++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 1 + mlir/tools/mlir-tblgen/OpFormatGen.cpp | 1 + 4 files changed, 14 insertions(+) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 847436e..ea39b9c 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1876,6 +1876,15 @@ def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ let assemblyFormat = "attr-dict $value"; } +def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [ + TypesMatchWith<"tuple result type matches operand type", "value", "result", + "::mlir::TupleType::get($_ctxt, $_self)"> + ]> { + let arguments = (ins AnyType:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + //===----------------------------------------------------------------------===// // Test SideEffects //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index e6f998f..759e7f5 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -348,3 +348,6 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64 // CHECK: test.format_types_match_attr 1 : i64 %ignored_res5 = test.format_types_match_attr 1 : i64 + +// CHECK: test.format_types_match_context %[[I64]] : i64 +%ignored_res6 = test.format_types_match_context %i64 : i64 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 78e84d7..cded9af 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -579,6 +579,7 @@ OpEmitter::OpEmitter(const Operator &op, opClass(op.getCppClassName(), op.getExtraClassDeclaration()), staticVerifierEmitter(staticVerifierEmitter) { verifyCtx.withOp("(*this->getOperation())"); + verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); genTraits(); diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 3c3f00f..2c91708 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1343,6 +1343,7 @@ void OperationFormat::genParserTypeResolution(Operator &op, } else if (const NamedTypeConstraint *var = resolver.getVariable()) { if (Optional tform = resolver.getVarTransformer()) { FmtContext fmtContext; + fmtContext.addSubst("_ctxt", "parser.getBuilder().getContext()"); if (var->isVariadic()) fmtContext.withSelf(var->name + "Types"); else -- 2.7.4