From 3ad167329aafde02e70b0327c0488602111a81ee Mon Sep 17 00:00:00 2001 From: =?utf8?q?Timm=20B=C3=A4der?= Date: Wed, 11 Jan 2023 12:12:52 +0100 Subject: [PATCH] [clang][Interp] Implement function pointers Differential Revision: https://reviews.llvm.org/D141472 --- clang/lib/AST/Interp/ByteCodeExprGen.cpp | 78 +++++++++++++++++++------------- clang/lib/AST/Interp/Context.cpp | 6 ++- clang/lib/AST/Interp/Descriptor.cpp | 1 + clang/lib/AST/Interp/FunctionPointer.h | 57 +++++++++++++++++++++++ clang/lib/AST/Interp/Interp.h | 17 +++++++ clang/lib/AST/Interp/InterpStack.h | 3 ++ clang/lib/AST/Interp/Opcodes.td | 18 +++++++- clang/lib/AST/Interp/PrimType.cpp | 1 + clang/lib/AST/Interp/PrimType.h | 6 +++ clang/test/AST/Interp/functions.cpp | 63 ++++++++++++++++++++++++++ 10 files changed, 216 insertions(+), 34 deletions(-) create mode 100644 clang/lib/AST/Interp/FunctionPointer.h diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp index fff2425..c6cf7f7 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -131,6 +131,11 @@ bool ByteCodeExprGen::VisitCastExpr(const CastExpr *CE) { return this->emitCastFloatingIntegral(*ToT, CE); } + case CK_NullToPointer: + if (DiscardResult) + return true; + return this->emitNull(classifyPrim(CE->getType()), CE); + case CK_ArrayToPointerDecay: case CK_AtomicToNonAtomic: case CK_ConstructorConversion: @@ -138,7 +143,6 @@ bool ByteCodeExprGen::VisitCastExpr(const CastExpr *CE) { case CK_NonAtomicToAtomic: case CK_NoOp: case CK_UserDefinedConversion: - case CK_NullToPointer: return this->visit(SubExpr); case CK_IntegralToBoolean: @@ -400,10 +404,7 @@ bool ByteCodeExprGen::VisitImplicitValueInitExpr(const ImplicitValueIni if (!T) return false; - if (E->getType()->isPointerType()) - return this->emitNullPtr(E); - - return this->emitZero(*T, E); + return this->visitZeroInitializer(*T, E); } template @@ -950,6 +951,8 @@ bool ByteCodeExprGen::visitZeroInitializer(PrimType T, const Expr *E) { return this->emitZeroUint64(E); case PT_Ptr: return this->emitNullPtr(E); + case PT_FnPtr: + return this->emitNullFnPtr(E); case PT_Float: assert(false); } @@ -1116,6 +1119,7 @@ bool ByteCodeExprGen::emitConst(T Value, const Expr *E) { case PT_Bool: return this->emitConstBool(Value, E); case PT_Ptr: + case PT_FnPtr: case PT_Float: llvm_unreachable("Invalid integral type"); break; @@ -1606,8 +1610,27 @@ bool ByteCodeExprGen::VisitCallExpr(const CallExpr *E) { if (E->getBuiltinCallee()) return VisitBuiltinCallExpr(E); - const Decl *Callee = E->getCalleeDecl(); - if (const auto *FuncDecl = dyn_cast_if_present(Callee)) { + QualType ReturnType = E->getCallReturnType(Ctx.getASTContext()); + std::optional T = classify(ReturnType); + bool HasRVO = !ReturnType->isVoidType() && !T; + + if (HasRVO && DiscardResult) { + // If we need to discard the return value but the function returns its + // value via an RVO pointer, we need to create one such pointer just + // for this call. + if (std::optional LocalIndex = allocateLocal(E)) { + if (!this->emitGetPtrLocal(*LocalIndex, E)) + return false; + } + } + + // Put arguments on the stack. + for (const auto *Arg : E->arguments()) { + if (!this->visit(Arg)) + return false; + } + + if (const FunctionDecl *FuncDecl = E->getDirectCallee()) { const Function *Func = getFunction(FuncDecl); if (!Func) return false; @@ -1619,24 +1642,7 @@ bool ByteCodeExprGen::VisitCallExpr(const CallExpr *E) { if (Func->isFullyCompiled() && !Func->isConstexpr()) return false; - QualType ReturnType = E->getCallReturnType(Ctx.getASTContext()); - std::optional T = classify(ReturnType); - - if (Func->hasRVO() && DiscardResult) { - // If we need to discard the return value but the function returns its - // value via an RVO pointer, we need to create one such pointer just - // for this call. - if (std::optional LocalIndex = allocateLocal(E)) { - if (!this->emitGetPtrLocal(*LocalIndex, E)) - return false; - } - } - - // Put arguments on the stack. - for (const auto *Arg : E->arguments()) { - if (!this->visit(Arg)) - return false; - } + assert(HasRVO == Func->hasRVO()); // In any case call the function. The return value will end up on the stack // and if the function has RVO, we already have the pointer on the stack to @@ -1644,15 +1650,22 @@ bool ByteCodeExprGen::VisitCallExpr(const CallExpr *E) { if (!this->emitCall(Func, E)) return false; - if (DiscardResult && !ReturnType->isVoidType() && T) - return this->emitPop(*T, E); - - return true; } else { - assert(false && "We don't support non-FunctionDecl callees right now."); + // Indirect call. Visit the callee, which will leave a FunctionPointer on + // the stack. Cleanup of the returned value if necessary will be done after + // the function call completed. + if (!this->visit(E->getCallee())) + return false; + + if (!this->emitCallPtr(E)) + return false; } - return false; + // Cleanup for discarded return values. + if (DiscardResult && !ReturnType->isVoidType() && T) + return this->emitPop(*T, E); + + return true; } template @@ -1846,6 +1859,9 @@ bool ByteCodeExprGen::VisitDeclRefExpr(const DeclRefExpr *E) { return this->emitConst(ECD->getInitVal(), E); } else if (const auto *BD = dyn_cast(Decl)) { return this->visit(BD->getBinding()); + } else if (const auto *FuncDecl = dyn_cast(Decl)) { + const Function *F = getFunction(FuncDecl); + return F && this->emitGetFnPtr(F, E); } return false; diff --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp index dcf41a5..6ede05e 100644 --- a/clang/lib/AST/Interp/Context.cpp +++ b/clang/lib/AST/Interp/Context.cpp @@ -78,9 +78,11 @@ bool Context::evaluateAsInitializer(State &Parent, const VarDecl *VD, const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); } std::optional Context::classify(QualType T) const { - if (T->isReferenceType() || T->isPointerType()) { + if (T->isFunctionPointerType() || T->isFunctionReferenceType()) + return PT_FnPtr; + + if (T->isReferenceType() || T->isPointerType()) return PT_Ptr; - } if (T->isBooleanType()) return PT_Bool; diff --git a/clang/lib/AST/Interp/Descriptor.cpp b/clang/lib/AST/Interp/Descriptor.cpp index 212311c..31554dd 100644 --- a/clang/lib/AST/Interp/Descriptor.cpp +++ b/clang/lib/AST/Interp/Descriptor.cpp @@ -9,6 +9,7 @@ #include "Descriptor.h" #include "Boolean.h" #include "Floating.h" +#include "FunctionPointer.h" #include "Pointer.h" #include "PrimType.h" #include "Record.h" diff --git a/clang/lib/AST/Interp/FunctionPointer.h b/clang/lib/AST/Interp/FunctionPointer.h new file mode 100644 index 0000000..2d449bd --- /dev/null +++ b/clang/lib/AST/Interp/FunctionPointer.h @@ -0,0 +1,57 @@ + + +#ifndef LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H +#define LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H + +#include "Function.h" +#include "Primitives.h" +#include "clang/AST/APValue.h" + +namespace clang { +namespace interp { + +class FunctionPointer final { +private: + const Function *Func; + +public: + FunctionPointer() : Func(nullptr) {} + FunctionPointer(const Function *Func) : Func(Func) { assert(Func); } + + const Function *getFunction() const { return Func; } + + APValue toAPValue() const { + if (!Func) + return APValue(static_cast(nullptr), CharUnits::Zero(), {}, + /*OnePastTheEnd=*/false, /*IsNull=*/true); + + return APValue(Func->getDecl(), CharUnits::Zero(), {}, + /*OnePastTheEnd=*/false, /*IsNull=*/false); + } + + void print(llvm::raw_ostream &OS) const { + OS << "FnPtr("; + if (Func) + OS << Func->getName(); + else + OS << "nullptr"; + OS << ")"; + } + + ComparisonCategoryResult compare(const FunctionPointer &RHS) const { + if (Func == RHS.Func) + return ComparisonCategoryResult::Equal; + return ComparisonCategoryResult::Unordered; + } +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + FunctionPointer FP) { + FP.print(OS); + return OS; +} + +} // namespace interp +} // namespace clang + +#endif diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h index 98561f0..bb34737 100644 --- a/clang/lib/AST/Interp/Interp.h +++ b/clang/lib/AST/Interp/Interp.h @@ -16,6 +16,7 @@ #include "Boolean.h" #include "Floating.h" #include "Function.h" +#include "FunctionPointer.h" #include "InterpFrame.h" #include "InterpStack.h" #include "InterpState.h" @@ -1538,6 +1539,22 @@ inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func) { return false; } +inline bool CallPtr(InterpState &S, CodePtr &PC) { + const FunctionPointer &FuncPtr = S.Stk.pop(); + + const Function *F = FuncPtr.getFunction(); + if (!F || !F->isConstexpr()) + return false; + + return Call(S, PC, F); +} + +inline bool GetFnPtr(InterpState &S, CodePtr &PC, const Function *Func) { + assert(Func); + S.Stk.push(Func); + return true; +} + //===----------------------------------------------------------------------===// // Read opcode arguments //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/Interp/InterpStack.h b/clang/lib/AST/Interp/InterpStack.h index 4987b2c..e625ffd 100644 --- a/clang/lib/AST/Interp/InterpStack.h +++ b/clang/lib/AST/Interp/InterpStack.h @@ -13,6 +13,7 @@ #ifndef LLVM_CLANG_AST_INTERP_INTERPSTACK_H #define LLVM_CLANG_AST_INTERP_INTERPSTACK_H +#include "FunctionPointer.h" #include "PrimType.h" #include #include @@ -162,6 +163,8 @@ private: return PT_Uint64; else if constexpr (std::is_same_v) return PT_Float; + else if constexpr (std::is_same_v) + return PT_FnPtr; llvm_unreachable("unknown type push()'ed into InterpStack"); } diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td index 80d5c652..f3662dc 100644 --- a/clang/lib/AST/Interp/Opcodes.td +++ b/clang/lib/AST/Interp/Opcodes.td @@ -27,6 +27,7 @@ def Sint64 : Type; def Uint64 : Type; def Float : Type; def Ptr : Type; +def FnPtr : Type; //===----------------------------------------------------------------------===// // Types transferred to the interpreter. @@ -77,7 +78,7 @@ def AluTypeClass : TypeClass { } def PtrTypeClass : TypeClass { - let Types = [Ptr]; + let Types = [Ptr, FnPtr]; } def BoolTypeClass : TypeClass { @@ -187,6 +188,12 @@ def CallBI : Opcode { let ChangesPC = 1; } +def CallPtr : Opcode { + let Args = []; + let Types = []; + let ChangesPC = 1; +} + //===----------------------------------------------------------------------===// // Frame management //===----------------------------------------------------------------------===// @@ -228,6 +235,7 @@ def Zero : Opcode { // [] -> [Pointer] def Null : Opcode { let Types = [PtrTypeClass]; + let HasGroup = 1; } //===----------------------------------------------------------------------===// @@ -448,6 +456,14 @@ def DecPtr : Opcode { } //===----------------------------------------------------------------------===// +// Function pointers. +//===----------------------------------------------------------------------===// +def GetFnPtr : Opcode { + let Args = [ArgFunction]; +} + + +//===----------------------------------------------------------------------===// // Binary operators. //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/Interp/PrimType.cpp b/clang/lib/AST/Interp/PrimType.cpp index da07b6f..a9b5d8e 100644 --- a/clang/lib/AST/Interp/PrimType.cpp +++ b/clang/lib/AST/Interp/PrimType.cpp @@ -9,6 +9,7 @@ #include "PrimType.h" #include "Boolean.h" #include "Floating.h" +#include "FunctionPointer.h" #include "Pointer.h" using namespace clang; diff --git a/clang/lib/AST/Interp/PrimType.h b/clang/lib/AST/Interp/PrimType.h index db9d8c3..91311cf 100644 --- a/clang/lib/AST/Interp/PrimType.h +++ b/clang/lib/AST/Interp/PrimType.h @@ -24,6 +24,7 @@ namespace interp { class Pointer; class Boolean; class Floating; +class FunctionPointer; /// Enumeration of the primitive types of the VM. enum PrimType : unsigned { @@ -38,6 +39,7 @@ enum PrimType : unsigned { PT_Bool, PT_Float, PT_Ptr, + PT_FnPtr, }; /// Mapping from primitive types to their representation. @@ -53,6 +55,9 @@ template <> struct PrimConv { using T = Integral<64, false>; }; template <> struct PrimConv { using T = Floating; }; template <> struct PrimConv { using T = Boolean; }; template <> struct PrimConv { using T = Pointer; }; +template <> struct PrimConv { + using T = FunctionPointer; +}; /// Returns the size of a primitive type in bytes. size_t primSize(PrimType Type); @@ -90,6 +95,7 @@ static inline bool aligned(const void *P) { TYPE_SWITCH_CASE(PT_Float, B) \ TYPE_SWITCH_CASE(PT_Bool, B) \ TYPE_SWITCH_CASE(PT_Ptr, B) \ + TYPE_SWITCH_CASE(PT_FnPtr, B) \ } \ } while (0) #define COMPOSITE_TYPE_SWITCH(Expr, B, D) \ diff --git a/clang/test/AST/Interp/functions.cpp b/clang/test/AST/Interp/functions.cpp index 4fb5d3d..48862d3 100644 --- a/clang/test/AST/Interp/functions.cpp +++ b/clang/test/AST/Interp/functions.cpp @@ -99,3 +99,66 @@ constexpr void invalid2() { huh(); // expected-error {{use of undeclared identifier}} \ // ref-error {{use of undeclared identifier}} } + +namespace FunctionPointers { + constexpr int add(int a, int b) { + return a + b; + } + + struct S { int a; }; + constexpr S getS() { + return S{12}; + } + + constexpr int applyBinOp(int a, int b, int (*op)(int, int)) { + return op(a, b); + } + static_assert(applyBinOp(1, 2, add) == 3, ""); + + constexpr int ignoreReturnValue() { + int (*foo)(int, int) = add; + + foo(1, 2); + return 1; + } + static_assert(ignoreReturnValue() == 1, ""); + + constexpr int createS(S (*gimme)()) { + gimme(); // Ignored return value + return gimme().a; + } + static_assert(createS(getS) == 12, ""); + +namespace FunctionReturnType { + typedef int (*ptr)(int*); + typedef ptr (*pm)(); + + constexpr int fun1(int* y) { + return *y + 10; + } + constexpr ptr fun() { + return &fun1; + } + static_assert(fun() == nullptr, ""); // expected-error {{static assertion failed}} \ + // ref-error {{static assertion failed}} + + constexpr int foo() { + int (*f)(int *) = fun(); + int m = 0; + + m = f(&m); + + return m; + } + static_assert(foo() == 10); + + struct S { + int i; + void (*fp)(); + }; + + constexpr S s{ 12 }; + static_assert(s.fp == nullptr); // zero-initialized function pointer. +} + +} -- 2.7.4