// Copyright (c) 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "spirv-tools/optimizer.hpp"
#include "spirv/1.1/spirv.h"

namespace {

using namespace spvtools;
using ::testing::ContainerEq;
using ::testing::HasSubstr;
27 TEST(CppInterface, SuccessfulRoundTrip) {
28 const std::string input_text = "%2 = OpSizeOf %1 %3\n";
29 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
31 std::vector<uint32_t> binary;
32 EXPECT_TRUE(t.Assemble(input_text, &binary));
33 EXPECT_TRUE(binary.size() > 5u);
34 EXPECT_EQ(SpvMagicNumber, binary[0]);
35 EXPECT_EQ(SpvVersion, binary[1]);
37 // This cannot pass validation since %1 is not defined.
38 t.SetMessageConsumer([](spv_message_level_t level, const char* source,
39 const spv_position_t& position, const char* message) {
40 EXPECT_EQ(SPV_MSG_ERROR, level);
41 EXPECT_STREQ("input", source);
42 EXPECT_EQ(0u, position.line);
43 EXPECT_EQ(0u, position.column);
44 EXPECT_EQ(1u, position.index);
45 EXPECT_STREQ("ID 1 has not been defined", message);
47 EXPECT_FALSE(t.Validate(binary));
49 std::string output_text;
50 EXPECT_TRUE(t.Disassemble(binary, &output_text));
51 EXPECT_EQ(input_text, output_text);
54 TEST(CppInterface, AssembleEmptyModule) {
55 std::vector<uint32_t> binary(10, 42);
56 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
57 EXPECT_TRUE(t.Assemble("", &binary));
58 // We only have the header.
59 EXPECT_EQ(5u, binary.size());
60 EXPECT_EQ(SpvMagicNumber, binary[0]);
61 EXPECT_EQ(SpvVersion, binary[1]);
64 TEST(CppInterface, AssembleOverloads) {
65 const std::string input_text = "%2 = OpSizeOf %1 %3\n";
66 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
68 std::vector<uint32_t> binary;
69 EXPECT_TRUE(t.Assemble(input_text, &binary));
70 EXPECT_TRUE(binary.size() > 5u);
71 EXPECT_EQ(SpvMagicNumber, binary[0]);
72 EXPECT_EQ(SpvVersion, binary[1]);
75 std::vector<uint32_t> binary;
76 EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size(), &binary));
77 EXPECT_TRUE(binary.size() > 5u);
78 EXPECT_EQ(SpvMagicNumber, binary[0]);
79 EXPECT_EQ(SpvVersion, binary[1]);
81 { // Ignore the last newline.
82 std::vector<uint32_t> binary;
83 EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size() - 1, &binary));
84 EXPECT_TRUE(binary.size() > 5u);
85 EXPECT_EQ(SpvMagicNumber, binary[0]);
86 EXPECT_EQ(SpvVersion, binary[1]);
90 TEST(CppInterface, AssembleWithWrongTargetEnv) {
91 const std::string input_text = "%r = OpSizeOf %type %pointer";
92 SpirvTools t(SPV_ENV_UNIVERSAL_1_0);
93 int invocation_count = 0;
95 [&invocation_count](spv_message_level_t level, const char* source,
96 const spv_position_t& position, const char* message) {
98 EXPECT_EQ(SPV_MSG_ERROR, level);
99 EXPECT_STREQ("input", source);
100 EXPECT_EQ(0u, position.line);
101 EXPECT_EQ(5u, position.column);
102 EXPECT_EQ(5u, position.index);
103 EXPECT_STREQ("Invalid Opcode name 'OpSizeOf'", message);
106 std::vector<uint32_t> binary = {42, 42};
107 EXPECT_FALSE(t.Assemble(input_text, &binary));
108 EXPECT_THAT(binary, ContainerEq(std::vector<uint32_t>{42, 42}));
109 EXPECT_EQ(1, invocation_count);
112 TEST(CppInterface, DisassembleEmptyModule) {
113 std::string text(10, 'x');
114 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
115 int invocation_count = 0;
116 t.SetMessageConsumer(
117 [&invocation_count](spv_message_level_t level, const char* source,
118 const spv_position_t& position, const char* message) {
120 EXPECT_EQ(SPV_MSG_ERROR, level);
121 EXPECT_STREQ("input", source);
122 EXPECT_EQ(0u, position.line);
123 EXPECT_EQ(0u, position.column);
124 EXPECT_EQ(0u, position.index);
125 EXPECT_STREQ("Missing module.", message);
127 EXPECT_FALSE(t.Disassemble({}, &text));
128 EXPECT_EQ("xxxxxxxxxx", text); // The original string is unmodified.
129 EXPECT_EQ(1, invocation_count);
132 TEST(CppInterface, DisassembleOverloads) {
133 const std::string input_text = "%2 = OpSizeOf %1 %3\n";
134 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
136 std::vector<uint32_t> binary;
137 EXPECT_TRUE(t.Assemble(input_text, &binary));
140 std::string output_text;
141 EXPECT_TRUE(t.Disassemble(binary, &output_text));
142 EXPECT_EQ(input_text, output_text);
145 std::string output_text;
146 EXPECT_TRUE(t.Disassemble(binary.data(), binary.size(), &output_text));
147 EXPECT_EQ(input_text, output_text);
151 TEST(CppInterface, DisassembleWithWrongTargetEnv) {
152 const std::string input_text = "%r = OpSizeOf %type %pointer";
153 SpirvTools t11(SPV_ENV_UNIVERSAL_1_1);
154 SpirvTools t10(SPV_ENV_UNIVERSAL_1_0);
155 int invocation_count = 0;
156 t10.SetMessageConsumer(
157 [&invocation_count](spv_message_level_t level, const char* source,
158 const spv_position_t& position, const char* message) {
160 EXPECT_EQ(SPV_MSG_ERROR, level);
161 EXPECT_STREQ("input", source);
162 EXPECT_EQ(0u, position.line);
163 EXPECT_EQ(0u, position.column);
164 EXPECT_EQ(5u, position.index);
165 EXPECT_STREQ("Invalid opcode: 321", message);
168 std::vector<uint32_t> binary;
169 EXPECT_TRUE(t11.Assemble(input_text, &binary));
171 std::string output_text(10, 'x');
172 EXPECT_FALSE(t10.Disassemble(binary, &output_text));
173 EXPECT_EQ("xxxxxxxxxx", output_text); // The original string is unmodified.
176 TEST(CppInterface, SuccessfulValidation) {
177 const std::string input_text = R"(
180 OpMemoryModel Logical GLSL450)";
181 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
182 int invocation_count = 0;
183 t.SetMessageConsumer([&invocation_count](spv_message_level_t, const char*,
184 const spv_position_t&, const char*) {
188 std::vector<uint32_t> binary;
189 EXPECT_TRUE(t.Assemble(input_text, &binary));
190 EXPECT_TRUE(t.Validate(binary));
191 EXPECT_EQ(0, invocation_count);
194 TEST(CppInterface, ValidateOverloads) {
195 const std::string input_text = R"(
198 OpMemoryModel Logical GLSL450)";
199 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
200 std::vector<uint32_t> binary;
201 EXPECT_TRUE(t.Assemble(input_text, &binary));
203 { EXPECT_TRUE(t.Validate(binary)); }
204 { EXPECT_TRUE(t.Validate(binary.data(), binary.size())); }
207 TEST(CppInterface, ValidateEmptyModule) {
208 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
209 int invocation_count = 0;
210 t.SetMessageConsumer(
211 [&invocation_count](spv_message_level_t level, const char* source,
212 const spv_position_t& position, const char* message) {
214 EXPECT_EQ(SPV_MSG_ERROR, level);
215 EXPECT_STREQ("input", source);
216 EXPECT_EQ(0u, position.line);
217 EXPECT_EQ(0u, position.column);
218 EXPECT_EQ(0u, position.index);
219 EXPECT_STREQ("Invalid SPIR-V magic number.", message);
221 EXPECT_FALSE(t.Validate({}));
222 EXPECT_EQ(1, invocation_count);
// Returns the assembly for a SPIR-V module that declares one struct type, %2,
// with |num_members| members. Every member has type %1, declared as
// `OpTypeInt 32 0`. A non-positive |num_members| yields a struct with no
// members.
std::string MakeModuleHavingStruct(int num_members) {
  std::ostringstream os;
  os << "OpCapability Shader\n"
        "OpCapability Linkage\n"
        "OpMemoryModel Logical GLSL450\n"
        "%1 = OpTypeInt 32 0\n"
        "%2 = OpTypeStruct";
  // Append one " %1" member operand per requested member.
  for (int i = 0; i < num_members; ++i) os << " %1";
  return os.str();
}
238 TEST(CppInterface, ValidateWithOptionsPass) {
239 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
240 std::vector<uint32_t> binary;
241 EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary));
242 const spvtools::ValidatorOptions opts;
244 EXPECT_TRUE(t.Validate(binary.data(), binary.size(), opts));
247 TEST(CppInterface, ValidateWithOptionsFail) {
248 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
249 std::vector<uint32_t> binary;
250 EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary));
251 spvtools::ValidatorOptions opts;
252 opts.SetUniversalLimit(spv_validator_limit_max_struct_members, 9);
253 std::stringstream os;
254 t.SetMessageConsumer([&os](spv_message_level_t, const char*,
255 const spv_position_t&,
256 const char* message) { os << message; });
258 EXPECT_FALSE(t.Validate(binary.data(), binary.size(), opts));
262 "Number of OpTypeStruct members (10) has exceeded the limit (9)"));
265 // Checks that after running the given optimizer |opt| on the given |original|
266 // source code, we can get the given |optimized| source code.
267 void CheckOptimization(const char* original, const char* optimized,
268 const Optimizer& opt) {
269 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
270 std::vector<uint32_t> original_binary;
271 ASSERT_TRUE(t.Assemble(original, &original_binary));
273 std::vector<uint32_t> optimized_binary;
274 EXPECT_TRUE(opt.Run(original_binary.data(), original_binary.size(),
277 std::string optimized_text;
278 EXPECT_TRUE(t.Disassemble(optimized_binary, &optimized_text));
279 EXPECT_EQ(optimized, optimized_text);
282 TEST(CppInterface, OptimizeEmptyModule) {
283 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
284 std::vector<uint32_t> binary;
285 EXPECT_TRUE(t.Assemble("", &binary));
287 Optimizer o(SPV_ENV_UNIVERSAL_1_1);
288 o.RegisterPass(CreateStripDebugInfoPass());
289 EXPECT_TRUE(o.Run(binary.data(), binary.size(), &binary));
292 TEST(CppInterface, OptimizeModifiedModule) {
293 Optimizer o(SPV_ENV_UNIVERSAL_1_1);
294 o.RegisterPass(CreateStripDebugInfoPass());
295 CheckOptimization("OpSource GLSL 450", "", o);
298 TEST(CppInterface, OptimizeMulitplePasses) {
299 const char* original_text =
301 "OpDecorate %true SpecId 1 "
302 "%bool = OpTypeBool "
303 "%true = OpSpecConstantTrue %bool";
305 Optimizer o(SPV_ENV_UNIVERSAL_1_1);
306 o.RegisterPass(CreateStripDebugInfoPass())
307 .RegisterPass(CreateFreezeSpecConstantValuePass());
309 const char* expected_text =
310 "%bool = OpTypeBool\n"
311 "%true = OpConstantTrue %bool\n";
313 CheckOptimization(original_text, expected_text, o);
316 TEST(CppInterface, OptimizeDoNothingWithPassToken) {
317 CreateFreezeSpecConstantValuePass();
318 auto token = CreateUnifyConstantPass();
321 TEST(CppInterface, OptimizeReassignPassToken) {
322 auto token = CreateNullPass();
323 token = CreateStripDebugInfoPass();
326 "OpSource GLSL 450", "",
327 Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token)));
330 TEST(CppInterface, OptimizeMoveConstructPassToken) {
331 auto token1 = CreateStripDebugInfoPass();
332 Optimizer::PassToken token2(std::move(token1));
335 "OpSource GLSL 450", "",
336 Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
339 TEST(CppInterface, OptimizeMoveAssignPassToken) {
340 auto token1 = CreateStripDebugInfoPass();
341 auto token2 = CreateNullPass();
342 token2 = std::move(token1);
345 "OpSource GLSL 450", "",
346 Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
349 TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) {
350 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
351 std::vector<uint32_t> binary;
352 ASSERT_TRUE(t.Assemble("OpSource GLSL 450", &binary));
354 EXPECT_TRUE(Optimizer(SPV_ENV_UNIVERSAL_1_1)
355 .RegisterPass(CreateStripDebugInfoPass())
356 .Run(binary.data(), binary.size(), &binary));
358 std::string optimized_text;
359 EXPECT_TRUE(t.Disassemble(binary, &optimized_text));
360 EXPECT_EQ("", optimized_text);
363 // TODO(antiagainst): tests for SetMessageConsumer().
365 } // anonymous namespace