Add-support-for-SPV_NV_mesh_shader
[platform/upstream/glslang.git] / gtests / TestFixture.cpp
1 //
2 // Copyright (C) 2016 Google, Inc.
3 //
4 // All rights reserved.
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
8 // are met:
9 //
10 //    Redistributions of source code must retain the above copyright
11 //    notice, this list of conditions and the following disclaimer.
12 //
13 //    Redistributions in binary form must reproduce the above
14 //    copyright notice, this list of conditions and the following
15 //    disclaimer in the documentation and/or other materials provided
16 //    with the distribution.
17 //
18 //    Neither the name of Google Inc. nor the names of its
19 //    contributors may be used to endorse or promote products derived
20 //    from this software without specific prior written permission.
21 //
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33 // POSSIBILITY OF SUCH DAMAGE.
34
35 #include "TestFixture.h"
36
37 namespace glslangtest {
38
39 std::string FileNameAsCustomTestSuffix(
40     const ::testing::TestParamInfo<std::string>& info)
41 {
42     std::string name = info.param;
43     // A valid test case suffix cannot have '.' and '-' inside.
44     std::replace(name.begin(), name.end(), '.', '_');
45     std::replace(name.begin(), name.end(), '-', '_');
46     return name;
47 }
48
49 EShLanguage GetShaderStage(const std::string& stage)
50 {
51     if (stage == "vert") {
52         return EShLangVertex;
53     } else if (stage == "tesc") {
54         return EShLangTessControl;
55     } else if (stage == "tese") {
56         return EShLangTessEvaluation;
57     } else if (stage == "geom") {
58         return EShLangGeometry;
59     } else if (stage == "frag") {
60         return EShLangFragment;
61     } else if (stage == "comp") {
62         return EShLangCompute;
63 #ifdef NV_EXTENSIONS
64     } else if (stage == "task") {
65         return EShLangTaskNV;
66     } else if (stage == "mesh") {
67         return EShLangMeshNV;
68 #endif
69     } else {
70         assert(0 && "Unknown shader stage");
71         return EShLangCount;
72     }
73 }
74
75 EShMessages DeriveOptions(Source source, Semantics semantics, Target target)
76 {
77     EShMessages result = EShMsgCascadingErrors;
78
79     switch (source) {
80         case Source::GLSL:
81             break;
82         case Source::HLSL:
83             result = static_cast<EShMessages>(result | EShMsgReadHlsl);
84             break;
85     }
86
87     switch (target) {
88         case Target::AST:
89             result = static_cast<EShMessages>(result | EShMsgAST);
90             break;
91         case Target::Spv:
92             result = static_cast<EShMessages>(result | EShMsgSpvRules);
93             result = static_cast<EShMessages>(result | EShMsgKeepUncalled);
94             break;
95         case Target::BothASTAndSpv:
96             result = static_cast<EShMessages>(result | EShMsgSpvRules | EShMsgAST);
97             result = static_cast<EShMessages>(result | EShMsgKeepUncalled);
98             break;
99     };
100
101     switch (semantics) {
102         case Semantics::OpenGL:
103             break;
104         case Semantics::Vulkan:
105             result = static_cast<EShMessages>(result | EShMsgVulkanRules | EShMsgSpvRules);
106             break;
107     }
108
109     result = static_cast<EShMessages>(result | EShMsgHlslLegalization);
110
111     return result;
112 }
113
114 std::pair<bool, std::string> ReadFile(const std::string& path)
115 {
116     std::ifstream fstream(path, std::ios::in);
117     if (fstream) {
118         std::string contents;
119         fstream.seekg(0, std::ios::end);
120         contents.reserve((std::string::size_type)fstream.tellg());
121         fstream.seekg(0, std::ios::beg);
122         contents.assign((std::istreambuf_iterator<char>(fstream)),
123                         std::istreambuf_iterator<char>());
124         return std::make_pair(true, contents);
125     }
126     return std::make_pair(false, "");
127 }
128
129 std::pair<bool, std::vector<std::uint32_t> > ReadSpvBinaryFile(const std::string& path)
130 {
131     std::ifstream fstream(path, std::fstream::in | std::fstream::binary);
132
133     if (!fstream)
134         return std::make_pair(false, std::vector<std::uint32_t>());
135
136     std::vector<std::uint32_t> contents;
137
138     // Reserve space (for efficiency, not for correctness)
139     fstream.seekg(0, fstream.end);
140     contents.reserve(size_t(fstream.tellg()) / sizeof(std::uint32_t));
141     fstream.seekg(0, fstream.beg);
142
143     // There is no istream iterator traversing by uint32_t, so we must loop.
144     while (!fstream.eof()) {
145         std::uint32_t inWord;
146         fstream.read((char *)&inWord, sizeof(inWord));
147
148         if (!fstream.eof())
149             contents.push_back(inWord);
150     }
151
152     return std::make_pair(true, contents); // hopefully, c++11 move semantics optimizes the copy away.
153 }
154
155 bool WriteFile(const std::string& path, const std::string& contents)
156 {
157     std::ofstream fstream(path, std::ios::out);
158     if (!fstream) return false;
159     fstream << contents;
160     fstream.flush();
161     return true;
162 }
163
164 std::string GetSuffix(const std::string& name)
165 {
166     const size_t pos = name.rfind('.');
167     return (pos == std::string::npos) ? "" : name.substr(name.rfind('.') + 1);
168 }
169
170 }  // namespace glslangtest