From 00ecacca7d90f96a1d54bc3fa38986fdd64e4c72 Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Fri, 2 Sep 2022 14:32:24 -0500 Subject: [PATCH] [HLSL] Generate buffer subscript operators In HLSL buffer types support array subscripting syntax for loads and stores. This change fleshes out the subscript operators to become array accesses on the underlying handle pointer. This will allow LLVM optimization passes to optimize resource accesses the same way any other memory access would be optimized. Reviewed By: aaron.ballman Differential Revision: https://reviews.llvm.org/D131268 --- clang/lib/Sema/HLSLExternalSemaSource.cpp | 111 ++++++++++++++++++++-- clang/lib/Sema/SemaType.cpp | 6 +- clang/test/AST/HLSL/RWBuffer-AST.hlsl | 23 ++++- clang/test/CodeGenHLSL/buffer-array-operator.hlsl | 30 ++++++ 4 files changed, 158 insertions(+), 12 deletions(-) create mode 100644 clang/test/CodeGenHLSL/buffer-array-operator.hlsl diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp index fe963fd..ee3aa4d 100644 --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -104,7 +104,14 @@ struct BuiltinTypeDeclBuilder { BuiltinTypeDeclBuilder & addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) { - return addMemberVariable("h", Record->getASTContext().VoidPtrTy, Access); + QualType Ty = Record->getASTContext().VoidPtrTy; + if (Template) { + if (const auto *TTD = dyn_cast( + Template->getTemplateParameters()->getParam(0))) + Ty = Record->getASTContext().getPointerType( + QualType(TTD->getTypeForDecl(), 0)); + } + return addMemberVariable("h", Ty, Access); } BuiltinTypeDeclBuilder & @@ -158,15 +165,25 @@ struct BuiltinTypeDeclBuilder { lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle"); Expr *RCExpr = emitResourceClassExpr(AST, RC); - CallExpr *Call = - CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue, - SourceLocation(), FPOptionsOverride()); + Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue, + SourceLocation(), FPOptionsOverride()); CXXThisExpr *This = new (AST) CXXThisExpr(SourceLocation(), Constructor->getThisType(), true); - MemberExpr *Handle = MemberExpr::CreateImplicit( - AST, This, true, Fields["h"], Fields["h"]->getType(), VK_LValue, - OK_Ordinary); + Expr *Handle = MemberExpr::CreateImplicit(AST, This, true, Fields["h"], + Fields["h"]->getType(), VK_LValue, + OK_Ordinary); + + // If the handle isn't a void pointer, cast the builtin result to the + // correct type. + if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) { + Call = CXXStaticCastExpr::Create( + AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr, + AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()), + FPOptionsOverride(), SourceLocation(), SourceLocation(), + SourceRange()); + } + BinaryOperator *Assign = BinaryOperator::Create( AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary, SourceLocation(), FPOptionsOverride()); @@ -179,6 +196,85 @@ struct BuiltinTypeDeclBuilder { return *this; } + BuiltinTypeDeclBuilder &addArraySubscriptOperators() { + addArraySubscriptOperator(true); + addArraySubscriptOperator(false); + return *this; + } + + BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) { + assert(Fields.count("h") > 0 && + "Subscript operator must be added after the handle."); + + FieldDecl *Handle = Fields["h"]; + ASTContext &AST = Record->getASTContext(); + + assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy && + "Not yet supported for void pointer handles."); + + QualType ElemTy = + QualType(Handle->getType()->getPointeeOrArrayElementType(), 0); + QualType ReturnTy = ElemTy; + + FunctionProtoType::ExtProtoInfo ExtInfo; + + // Subscript operators return references to elements, const makes the + // reference and method const so that the underlying data is not mutable. + ReturnTy = AST.getLValueReferenceType(ReturnTy); + if (IsConst) { + ExtInfo.TypeQuals.addConst(); + ReturnTy.addConst(); + } + + QualType MethodTy = + AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo); + auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation()); + auto *MethodDecl = CXXMethodDecl::Create( + AST, Record, SourceLocation(), + DeclarationNameInfo( + AST.DeclarationNames.getCXXOperatorName(OO_Subscript), + SourceLocation()), + MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified, + SourceLocation()); + + IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier); + auto *IdxParam = ParmVarDecl::Create( + AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(), + &II, AST.UnsignedIntTy, + AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()), + SC_None, nullptr); + MethodDecl->setParams({IdxParam}); + + // Also add the parameter to the function prototype. + auto FnProtoLoc = TSInfo->getTypeLoc().getAs(); + FnProtoLoc.setParam(0, IdxParam); + + auto *This = new (AST) + CXXThisExpr(SourceLocation(), MethodDecl->getThisType(), true); + auto *HandleAccess = MemberExpr::CreateImplicit( + AST, This, true, Handle, Handle->getType(), VK_LValue, OK_Ordinary); + + auto *IndexExpr = DeclRefExpr::Create( + AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false, + DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()), + AST.UnsignedIntTy, VK_PRValue); + + auto *Array = + new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue, + OK_Ordinary, SourceLocation()); + + auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr); + + MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(), + SourceLocation(), + SourceLocation())); + MethodDecl->setLexicalDeclContext(Record); + MethodDecl->setAccess(AccessSpecifier::AS_public); + Record->addDecl(MethodDecl); + + return *this; + } + BuiltinTypeDeclBuilder &startDefinition() { Record->startDefinition(); return *this; @@ -368,6 +464,7 @@ void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) { BuiltinTypeDeclBuilder(Record) .addHandleMember() .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV) + .addArraySubscriptOperators() .annotateResourceClass(HLSLResourceAttr::UAV) .completeDefinition(); } diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp index 313a534..e87b59d 100644 --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -2174,7 +2174,7 @@ QualType Sema::BuildPointerType(QualType T, return QualType(); } - if (getLangOpts().HLSL) { + if (getLangOpts().HLSL && Loc.isValid()) { Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0; return QualType(); } @@ -2244,7 +2244,7 @@ QualType Sema::BuildReferenceType(QualType T, bool SpelledAsLValue, return QualType(); } - if (getLangOpts().HLSL) { + if (getLangOpts().HLSL && Loc.isValid()) { Diag(Loc, diag::err_hlsl_pointers_unsupported) << 1; return QualType(); } @@ -3008,7 +3008,7 @@ QualType Sema::BuildMemberPointerType(QualType T, QualType Class, return QualType(); } - if (getLangOpts().HLSL) { + if (getLangOpts().HLSL && Loc.isValid()) { Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0; return QualType(); } diff --git a/clang/test/AST/HLSL/RWBuffer-AST.hlsl b/clang/test/AST/HLSL/RWBuffer-AST.hlsl index c9cbd73..193ef67 100644 --- a/clang/test/AST/HLSL/RWBuffer-AST.hlsl +++ b/clang/test/AST/HLSL/RWBuffer-AST.hlsl @@ -39,11 +39,30 @@ RWBuffer Buffer; // CHECK: FinalAttr 0x{{[0-9A-Fa-f]+}} <> Implicit final // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <> Implicit UAV -// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit h 'void *' +// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit h 'element_type *' + +// CHECK: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <> operator[] 'element_type &const (unsigned int) const' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <> Idx 'unsigned int' +// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type' lvalue +// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}} +// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <> 'const RWBuffer *' implicit this +// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int' + +// CHECK-NEXT: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <> operator[] 'element_type &(unsigned int)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <> Idx 'unsigned int' +// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <> +// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type' lvalue +// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}} +// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <> 'RWBuffer *' implicit this +// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int' + // CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <> class RWBuffer definition // CHECK: TemplateArgument type 'float' // CHECK-NEXT: BuiltinType 0x{{[0-9A-Fa-f]+}} 'float' // CHECK-NEXT: FinalAttr 0x{{[0-9A-Fa-f]+}} <> Implicit final // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <> Implicit UAV -// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit referenced h 'void *' +// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <> implicit referenced h 'float *' diff --git a/clang/test/CodeGenHLSL/buffer-array-operator.hlsl b/clang/test/CodeGenHLSL/buffer-array-operator.hlsl new file mode 100644 index 0000000..6bcb061 --- /dev/null +++ b/clang/test/CodeGenHLSL/buffer-array-operator.hlsl @@ -0,0 +1,30 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s + +const RWBuffer In; +RWBuffer Out; + +void fn(int Idx) { + Out[Idx] = In[Idx]; +} + +// This test is intended to verify reasonable code generation of the subscript +// operator. In this test case we should be generating both the const and +// non-const operators so we verify both cases. + +// Non-const comes first. +// CHECK: ptr @"??A?$RWBuffer@M@hlsl@@QBAAAMI@Z" +// CHECK: %this1 = load ptr, ptr %this.addr, align 4 +// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0 +// CHECK-NEXT: %0 = load ptr, ptr %h, align 4 +// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4 +// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1 +// CHECK-NEXT: ret ptr %arrayidx + +// Const comes next, and returns the pointer instead of the value. +// CHECK: ptr @"??A?$RWBuffer@M@hlsl@@QAAAAMI@Z" +// CHECK: %this1 = load ptr, ptr %this.addr, align 4 +// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0 +// CHECK-NEXT: %0 = load ptr, ptr %h, align 4 +// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4 +// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1 +// CHECK-NEXT: ret ptr %arrayidx -- 2.7.4