From 6059122601178fe64353387afd344c7555bd6372 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 5 Aug 2019 01:57:27 -0700 Subject: [PATCH] Introduce custom syntax for llvm.func Similar to all LLVM dialect operations, llvm.func needs to have the custom syntax. Use the generic FunctionLike printer and parser to implement it. PiperOrigin-RevId: 261641755 --- mlir/include/mlir/IR/FunctionSupport.h | 11 ++-- mlir/include/mlir/LLVMIR/LLVMDialect.h | 5 ++ mlir/include/mlir/LLVMIR/LLVMOps.td | 4 ++ mlir/lib/IR/Function.cpp | 5 +- mlir/lib/IR/FunctionSupport.cpp | 8 ++- mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 71 ++++++++++++++++++++- mlir/test/LLVMIR/func.mlir | 111 ++++++++++++++++++++++++++++----- 7 files changed, 190 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index a70013a..192f5dd 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -55,15 +55,16 @@ inline ArrayRef getArgAttrs(Operation *op, unsigned index) { /// Callback type for `parseFunctionLikeOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of -/// function arguments and results. -using FuncTypeBuilder = - llvm::function_ref, ArrayRef)>; +/// function arguments and results; in case of error, it may populate the last +/// argument with a message. +using FuncTypeBuilder = llvm::function_ref, + ArrayRef, std::string &)>; /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of /// input and output types. If the builder returns a null type, `result` will -/// not contain the `type` attribute. The caller can then either add the type -/// or use op's verifier to report errors. +/// not contain the `type` attribute. The caller can then add a type, report +/// the error or delegate the reporting to the op's verifier. ParseResult parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, FuncTypeBuilder funcTypeBuilder); diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index 2f98828..55479f2 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -67,6 +67,11 @@ public: /// Array type utilities. LLVMType getArrayElementType(); + /// Function type utilities. + LLVMType getFunctionParamType(unsigned argIdx); + unsigned getFunctionNumParams(); + LLVMType getFunctionResultType(); + /// Pointer type utilities. LLVMType getPointerTo(unsigned addrSpace = 0); LLVMType getPointerElementTy(); diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index 5c01391..9031242 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -337,6 +337,10 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", }]; let verifier = [{ return ::verify(*this); }]; + let printer = [{ printLLVMFuncOp(p, *this); }]; + let parser = [{ + return impl::parseFunctionLikeOp(parser, result, buildLLVMFunctionType); + }]; } def LLVM_UndefOp : LLVM_OneResultOp<"undef", [NoSideEffect]>, diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 106b670..af0edf9 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -78,9 +78,8 @@ void FuncOp::build(Builder *builder, OperationState *result, StringRef name, ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) { return impl::parseFunctionLikeOp( parser, result, - [](Builder &builder, ArrayRef argTypes, ArrayRef results) { - return builder.getFunctionType(argTypes, results); - }); + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + std::string &) { return builder.getFunctionType(argTypes, results); }); } void FuncOp::print(OpAsmPrinter *p) { diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index 081da75..92285e4 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -110,11 +110,17 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, result->attributes.back().second = builder.getStringAttr(nameAttr.getValue()); // Parse the function signature. + auto signatureLocation = parser->getCurrentLocation(); if (parseFunctionSignature(parser, entryArgs, argTypes, argAttrs, results)) return failure(); - if (auto type = funcTypeBuilder(builder, argTypes, results)) + std::string errorMessage; + if (auto type = funcTypeBuilder(builder, argTypes, results, errorMessage)) result->addAttribute(getTypeAttrName(), builder.getTypeAttr(type)); + else + return parser->emitError(signatureLocation) + << "failed to construct function type" + << (errorMessage.empty() ? "" : ": ") << errorMessage; // If function attributes are present, parse them. if (succeeded(parser->parseOptionalKeyword("attributes"))) diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index da46e8d..1315fdd 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -703,7 +703,7 @@ static ParseResult parseConstantOp(OpAsmParser *parser, } //===----------------------------------------------------------------------===// -// Builder and verifier for LLVM::LLVMFuncOp. +// Builder, printer and verifier for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, @@ -726,6 +726,62 @@ void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, result->addAttribute(getArgAttrName(i, argAttrName), argDict); } +// Build an LLVM function type from the given lists of input and output types. +// Returns a null type if any of the types provided are non-LLVM types, or if +// there is more than one output type. +static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, + ArrayRef outputs, + std::string &errorMessage) { + if (outputs.size() > 1) { + errorMessage = "expected zero or one function result"; + return {}; + } + + // Convert inputs to LLVM types, exit early on error. + SmallVector llvmInputs; + for (auto t : inputs) { + auto llvmTy = t.dyn_cast(); + if (!llvmTy) { + errorMessage = "expected LLVM type for function arguments"; + return {}; + } + llvmInputs.push_back(llvmTy); + } + + // Get the dialect from the input type, if any exist. Look it up in the + // context otherwise. + LLVMDialect *dialect = + llvmInputs.empty() ? b.getContext()->getRegisteredDialect() + : &llvmInputs.front().getDialect(); + + // No output is denoted as "void" in LLVM type system. + LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) + : outputs.front().dyn_cast(); + if (!llvmOutput) { + errorMessage = "expected LLVM type for function results"; + return {}; + } + return LLVMType::getFunctionTy(llvmOutput, llvmInputs, + /*isVarArg=*/false); +} + +// Print the LLVMFuncOp. Collects argument and result types and passes them +// to the trait printer. Drops "void" result since it cannot be parsed back. +static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) { + LLVMType fnType = op.getType(); + SmallVector argTypes; + SmallVector resTypes; + argTypes.reserve(fnType.getFunctionNumParams()); + for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) + argTypes.push_back(fnType.getFunctionParamType(i)); + + LLVMType returnType = fnType.getFunctionResultType(); + if (!returnType.getUnderlyingType()->isVoidTy()) + resTypes.push_back(returnType); + + impl::printFunctionLikeOp(p, op, argTypes, resTypes); +} + // Hook for OpTrait::FunctionLike, called after verifying that the 'type' // attribute is present. This can check for preconditions of the // getNumArguments hook not failing. @@ -914,6 +970,19 @@ LLVMType LLVMType::getArrayElementType() { return get(getContext(), getUnderlyingType()->getArrayElementType()); } +/// Function type utilities. +LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { + return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); +} +unsigned LLVMType::getFunctionNumParams() { + return getUnderlyingType()->getFunctionNumParams(); +} +LLVMType LLVMType::getFunctionResultType() { + return get( + getContext(), + llvm::cast(getUnderlyingType())->getReturnType()); +} + /// Pointer type utilities. LLVMType LLVMType::getPointerTo(unsigned addrSpace) { // Lock access to the dialect as this may modify the LLVM context. diff --git a/mlir/test/LLVMIR/func.mlir b/mlir/test/LLVMIR/func.mlir index f056d37..f0dc3f5 100644 --- a/mlir/test/LLVMIR/func.mlir +++ b/mlir/test/LLVMIR/func.mlir @@ -1,31 +1,91 @@ -// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s +// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-op-generic %s | FileCheck %s --check-prefix=GENERIC module { - // CHECK: "llvm.func" - // CHECK: sym_name = "foo" - // CHECK-SAME: type = !llvm<"void ()"> - // CHECK-SAME: () -> () + // GENERIC: "llvm.func" + // GENERIC: sym_name = "foo" + // GENERIC-SAME: type = !llvm<"void ()"> + // GENERIC-SAME: () -> () + // CHECK: llvm.func @foo() "llvm.func"() ({ }) {sym_name = "foo", type = !llvm<"void ()">} : () -> () - // CHECK: "llvm.func" - // CHECK: sym_name = "bar" - // CHECK-SAME: type = !llvm<"i64 (i64, i64)"> - // CHECK-SAME: () -> () + // GENERIC: "llvm.func" + // GENERIC: sym_name = "bar" + // GENERIC-SAME: type = !llvm<"i64 (i64, i64)"> + // GENERIC-SAME: () -> () + // CHECK: llvm.func @bar(!llvm.i64, !llvm.i64) -> !llvm.i64 "llvm.func"() ({ }) {sym_name = "bar", type = !llvm<"i64 (i64, i64)">} : () -> () - // CHECK: "llvm.func" + // GENERIC: "llvm.func" + // CHECK: llvm.func @baz(%{{.*}}: !llvm.i64) -> !llvm.i64 "llvm.func"() ({ - // CHECK: ^bb0 + // GENERIC: ^bb0 ^bb0(%arg0: !llvm.i64): - // CHECK: llvm.return + // GENERIC: llvm.return llvm.return %arg0 : !llvm.i64 - // CHECK: sym_name = "baz" - // CHECK-SAME: type = !llvm<"i64 (i64)"> - // CHECK-SAME: () -> () + // GENERIC: sym_name = "baz" + // GENERIC-SAME: type = !llvm<"i64 (i64)"> + // GENERIC-SAME: () -> () }) {sym_name = "baz", type = !llvm<"i64 (i64)">} : () -> () + + // CHECK: llvm.func @qux(!llvm<"i64*"> {llvm.noalias = true}, !llvm.i64) + // CHECK-NEXT: attributes {xxx = {yyy = 42 : i64}} + "llvm.func"() ({ + }) {sym_name = "qux", type = !llvm<"void (i64*, i64)">, + arg0 = {llvm.noalias = true}, xxx = {yyy = 42}} : () -> () + + // CHECK: llvm.func @roundtrip1() + llvm.func @roundtrip1() + + // CHECK: llvm.func @roundtrip2(!llvm.i64, !llvm.float) -> !llvm.double + llvm.func @roundtrip2(!llvm.i64, !llvm.float) -> !llvm.double + + // CHECK: llvm.func @roundtrip3(!llvm.i32, !llvm.i1) + llvm.func @roundtrip3(%a: !llvm.i32, %b: !llvm.i1) + + // CHECK: llvm.func @roundtrip4(%{{.*}}: !llvm.i32, %{{.*}}: !llvm.i1) { + llvm.func @roundtrip4(%a: !llvm.i32, %b: !llvm.i1) { + llvm.return + } + + // CHECK: llvm.func @roundtrip5() + // CHECK-NEXT: attributes {baz = 42 : i64, foo = "bar"} + llvm.func @roundtrip5() attributes {foo = "bar", baz = 42} + + // CHECK: llvm.func @roundtrip6() + // CHECK-NEXT: attributes {baz = 42 : i64, foo = "bar"} + llvm.func @roundtrip6() attributes {foo = "bar", baz = 42} { + llvm.return + } + + // CHECK: llvm.func @roundtrip7() { + llvm.func @roundtrip7() attributes {} { + llvm.return + } + + // CHECK: llvm.func @roundtrip8() -> !llvm.i32 + llvm.func @roundtrip8() -> !llvm.i32 attributes {} + + // CHECK: llvm.func @roundtrip9(!llvm<"i32*"> {llvm.noalias = true}) + llvm.func @roundtrip9(!llvm<"i32*"> {llvm.noalias = true}) + + // CHECK: llvm.func @roundtrip10(!llvm<"i32*"> {llvm.noalias = true}) + llvm.func @roundtrip10(%arg0: !llvm<"i32*"> {llvm.noalias = true}) + + // CHECK: llvm.func @roundtrip11(%{{.*}}: !llvm<"i32*"> {llvm.noalias = true}) { + llvm.func @roundtrip11(%arg0: !llvm<"i32*"> {llvm.noalias = true}) { + llvm.return + } + + // CHECK: llvm.func @roundtrip12(%{{.*}}: !llvm<"i32*"> {llvm.noalias = true}) + // CHECK-NEXT: attributes {foo = 42 : i32} + llvm.func @roundtrip12(%arg0: !llvm<"i32*"> {llvm.noalias = true}) + attributes {foo = 42 : i32} { + llvm.return + } } // ----- @@ -85,3 +145,24 @@ module { llvm.return }) {sym_name = "wrong_arg_number", type = !llvm<"void (i64)">} : () -> () } + +// ----- + +module { + // expected-error@+1 {{failed to construct function type: expected LLVM type for function arguments}} + llvm.func @foo(i64) +} + +// ----- + +module { + // expected-error@+1 {{failed to construct function type: expected LLVM type for function results}} + llvm.func @foo() -> i64 +} + +// ----- + +module { + // expected-error@+1 {{failed to construct function type: expected zero or one function result}} + llvm.func @foo() -> (!llvm.i64, !llvm.i64) +} -- 2.7.4