currentFunctionType = new TType(EbtVoid);
functionReturnsValue = false;
- inEntryPoint = function.getName().compare(intermediate.getEntryPointName().c_str()) == 0;
- if (inEntryPoint) {
- intermediate.setEntryPointMangledName(function.getMangledName().c_str());
- intermediate.incrementEntryPointCount();
- remapEntryPointIO(function);
- if (entryPointOutput) {
- if (shouldFlatten(entryPointOutput->getType()))
- flatten(loc, *entryPointOutput);
- if (shouldSplit(entryPointOutput->getType()))
- split(*entryPointOutput);
- assignLocations(*entryPointOutput);
- }
- } else
- remapNonEntryPointIO(function);
+ // Entry points need different I/O and other handling, transform it so the
+ // rest of this function doesn't care.
+ transformEntryPoint(loc, function, attributes);
// Insert the $Global constant buffer.
// TODO: this design fails if new members are declared between function definitions.
controlFlowNestingLevel = 0;
postEntryPointReturn = false;
- // Handle function attributes
- if (inEntryPoint) {
- const TIntermAggregate* numThreads = attributes[EatNumThreads];
- if (numThreads != nullptr) {
- const TIntermSequence& sequence = numThreads->getSequence();
+ return paramNodes;
+}
- for (int lid = 0; lid < int(sequence.size()); ++lid)
- intermediate.setLocalSize(lid, sequence[lid]->getAsConstantUnion()->getConstArray()[0].getIConst());
- }
+//
+// Do all special handling for the entry point.
+//
+void HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& function, const TAttributeMap& attributes)
+{
+ inEntryPoint = function.getName().compare(intermediate.getEntryPointName().c_str()) == 0;
- const TIntermAggregate* maxVertexCount = attributes[EatMaxVertexCount];
- if (maxVertexCount != nullptr) {
- intermediate.setVertices(maxVertexCount->getSequence()[0]->getAsConstantUnion()->getConstArray()[0].getIConst());
- }
+ if (!inEntryPoint) {
+ remapNonEntryPointIO(function);
+ return;
}
- return paramNodes;
+ // entry point logic...
+
+ intermediate.setEntryPointMangledName(function.getMangledName().c_str());
+ intermediate.incrementEntryPointCount();
+
+ // Handle parameters and return value
+ remapEntryPointIO(function);
+ if (entryPointOutput) {
+ if (shouldFlatten(entryPointOutput->getType()))
+ flatten(loc, *entryPointOutput);
+ if (shouldSplit(entryPointOutput->getType()))
+ split(*entryPointOutput);
+ assignLocations(*entryPointOutput);
+ }
+
+ // Handle function attributes
+ const TIntermAggregate* numThreads = attributes[EatNumThreads];
+ if (numThreads != nullptr) {
+ const TIntermSequence& sequence = numThreads->getSequence();
+
+ for (int lid = 0; lid < int(sequence.size()); ++lid)
+ intermediate.setLocalSize(lid, sequence[lid]->getAsConstantUnion()->getConstArray()[0].getIConst());
+ }
+ const TIntermAggregate* maxVertexCount = attributes[EatMaxVertexCount];
+ if (maxVertexCount != nullptr)
+ intermediate.setVertices(maxVertexCount->getSequence()[0]->getAsConstantUnion()->getConstArray()[0].getIConst());
}
void HlslParseContext::handleFunctionBody(const TSourceLoc& loc, TFunction& function, TIntermNode* functionBody, TIntermNode*& node)
void assignLocations(TVariable& variable);
TFunction& handleFunctionDeclarator(const TSourceLoc&, TFunction& function, bool prototype);
TIntermAggregate* handleFunctionDefinition(const TSourceLoc&, TFunction&, const TAttributeMap&);
+ void transformEntryPoint(const TSourceLoc&, TFunction&, const TAttributeMap&);
void handleFunctionBody(const TSourceLoc&, TFunction&, TIntermNode* functionBody, TIntermNode*& node);
void remapEntryPointIO(TFunction& function);
void remapNonEntryPointIO(TFunction& function);