SROA: Do replacement on structs with no partial references.
[platform/upstream/SPIRV-Tools.git] / test / opt / loop_optimizations / nested_loops.cpp
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <gmock/gmock.h>
16
17 #include <memory>
18 #include <string>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "../assembly_builder.h"
23 #include "../function_utils.h"
24 #include "../pass_fixture.h"
25 #include "../pass_utils.h"
26
27 #include "opt/iterator.h"
28 #include "opt/loop_descriptor.h"
29 #include "opt/pass.h"
30 #include "opt/tree_iterator.h"
31
32 namespace {
33
34 using namespace spvtools;
35 using ::testing::UnorderedElementsAre;
36
37 bool Validate(const std::vector<uint32_t>& bin) {
38   spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
39   spv_context spvContext = spvContextCreate(target_env);
40   spv_diagnostic diagnostic = nullptr;
41   spv_const_binary_t binary = {bin.data(), bin.size()};
42   spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
43   if (error != 0) spvDiagnosticPrint(diagnostic);
44   spvDiagnosticDestroy(diagnostic);
45   spvContextDestroy(spvContext);
46   return error == 0;
47 }
48
49 using PassClassTest = PassTest<::testing::Test>;
50
51 /*
52 Generated from the following GLSL
53 #version 330 core
54 layout(location = 0) out vec4 c;
55 void main() {
56   int i = 0;
57   for (; i < 10; ++i) {
58     int j = 0;
59     int k = 0;
60     for (; j < 11; ++j) {}
61     for (; k < 12; ++k) {}
62   }
63 }
64 */
65 TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
66   const std::string text = R"(
67                OpCapability Shader
68           %1 = OpExtInstImport "GLSL.std.450"
69                OpMemoryModel Logical GLSL450
70                OpEntryPoint Fragment %2 "main" %3
71                OpExecutionMode %2 OriginUpperLeft
72                OpSource GLSL 330
73                OpName %2 "main"
74                OpName %4 "i"
75                OpName %5 "j"
76                OpName %6 "k"
77                OpName %3 "c"
78                OpDecorate %3 Location 0
79           %7 = OpTypeVoid
80           %8 = OpTypeFunction %7
81           %9 = OpTypeInt 32 1
82          %10 = OpTypePointer Function %9
83          %11 = OpConstant %9 0
84          %12 = OpConstant %9 10
85          %13 = OpTypeBool
86          %14 = OpConstant %9 11
87          %15 = OpConstant %9 1
88          %16 = OpConstant %9 12
89          %17 = OpTypeFloat 32
90          %18 = OpTypeVector %17 4
91          %19 = OpTypePointer Output %18
92           %3 = OpVariable %19 Output
93           %2 = OpFunction %7 None %8
94          %20 = OpLabel
95           %4 = OpVariable %10 Function
96           %5 = OpVariable %10 Function
97           %6 = OpVariable %10 Function
98                OpStore %4 %11
99                OpBranch %21
100          %21 = OpLabel
101                OpLoopMerge %22 %23 None
102                OpBranch %24
103          %24 = OpLabel
104          %25 = OpLoad %9 %4
105          %26 = OpSLessThan %13 %25 %12
106                OpBranchConditional %26 %27 %22
107          %27 = OpLabel
108                OpStore %5 %11
109                OpStore %6 %11
110                OpBranch %28
111          %28 = OpLabel
112                OpLoopMerge %29 %30 None
113                OpBranch %31
114          %31 = OpLabel
115          %32 = OpLoad %9 %5
116          %33 = OpSLessThan %13 %32 %14
117                OpBranchConditional %33 %34 %29
118          %34 = OpLabel
119                OpBranch %30
120          %30 = OpLabel
121          %35 = OpLoad %9 %5
122          %36 = OpIAdd %9 %35 %15
123                OpStore %5 %36
124                OpBranch %28
125          %29 = OpLabel
126                OpBranch %37
127          %37 = OpLabel
128                OpLoopMerge %38 %39 None
129                OpBranch %40
130          %40 = OpLabel
131          %41 = OpLoad %9 %6
132          %42 = OpSLessThan %13 %41 %16
133                OpBranchConditional %42 %43 %38
134          %43 = OpLabel
135                OpBranch %39
136          %39 = OpLabel
137          %44 = OpLoad %9 %6
138          %45 = OpIAdd %9 %44 %15
139                OpStore %6 %45
140                OpBranch %37
141          %38 = OpLabel
142                OpBranch %23
143          %23 = OpLabel
144          %46 = OpLoad %9 %4
145          %47 = OpIAdd %9 %46 %15
146                OpStore %4 %47
147                OpBranch %21
148          %22 = OpLabel
149                OpReturn
150                OpFunctionEnd
151   )";
152   // clang-format on
153   std::unique_ptr<ir::IRContext> context =
154       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
155                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
156   ir::Module* module = context->module();
157   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
158                              << text << std::endl;
159   const ir::Function* f = spvtest::GetFunction(module, 2);
160   ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f);
161
162   EXPECT_EQ(ld.NumLoops(), 3u);
163
164   // Invalid basic block id.
165   EXPECT_EQ(ld[0u], nullptr);
166   // Not a loop header.
167   EXPECT_EQ(ld[20], nullptr);
168
169   ir::Loop& parent_loop = *ld[21];
170   EXPECT_TRUE(parent_loop.HasNestedLoops());
171   EXPECT_FALSE(parent_loop.IsNested());
172   EXPECT_EQ(parent_loop.GetDepth(), 1u);
173   EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u);
174   EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21));
175   EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23));
176   EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22));
177
178   ir::Loop& child_loop_1 = *ld[28];
179   EXPECT_FALSE(child_loop_1.HasNestedLoops());
180   EXPECT_TRUE(child_loop_1.IsNested());
181   EXPECT_EQ(child_loop_1.GetDepth(), 2u);
182   EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u);
183   EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28));
184   EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30));
185   EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29));
186
187   ir::Loop& child_loop_2 = *ld[37];
188   EXPECT_FALSE(child_loop_2.HasNestedLoops());
189   EXPECT_TRUE(child_loop_2.IsNested());
190   EXPECT_EQ(child_loop_2.GetDepth(), 2u);
191   EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u);
192   EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37));
193   EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39));
194   EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38));
195 }
196
197 static void CheckLoopBlocks(ir::Loop* loop,
198                             std::unordered_set<uint32_t>* expected_ids) {
199   SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id()));
200   for (uint32_t bb_id : loop->GetBlocks()) {
201     EXPECT_EQ(expected_ids->count(bb_id), 1u);
202     expected_ids->erase(bb_id);
203   }
204   EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
205   EXPECT_EQ(expected_ids->size(), 0u);
206 }
207
208 /*
209 Generated from the following GLSL
210 #version 330 core
211 layout(location = 0) out vec4 c;
212 void main() {
213   int i = 0;
214   for (; i < 10; ++i) {
215     for (int j = 0; j < 11; ++j) {
216       if (j < 5) {
217         for (int k = 0; k < 12; ++k) {}
218       }
219       else {}
220       for (int k = 0; k < 12; ++k) {}
221     }
222   }
223 }*/
224 TEST_F(PassClassTest, TripleNestedLoop) {
225   const std::string text = R"(
226                OpCapability Shader
227           %1 = OpExtInstImport "GLSL.std.450"
228                OpMemoryModel Logical GLSL450
229                OpEntryPoint Fragment %2 "main" %3
230                OpExecutionMode %2 OriginUpperLeft
231                OpSource GLSL 330
232                OpName %2 "main"
233                OpName %4 "i"
234                OpName %5 "j"
235                OpName %6 "k"
236                OpName %7 "k"
237                OpName %3 "c"
238                OpDecorate %3 Location 0
239           %8 = OpTypeVoid
240           %9 = OpTypeFunction %8
241          %10 = OpTypeInt 32 1
242          %11 = OpTypePointer Function %10
243          %12 = OpConstant %10 0
244          %13 = OpConstant %10 10
245          %14 = OpTypeBool
246          %15 = OpConstant %10 11
247          %16 = OpConstant %10 5
248          %17 = OpConstant %10 12
249          %18 = OpConstant %10 1
250          %19 = OpTypeFloat 32
251          %20 = OpTypeVector %19 4
252          %21 = OpTypePointer Output %20
253           %3 = OpVariable %21 Output
254           %2 = OpFunction %8 None %9
255          %22 = OpLabel
256           %4 = OpVariable %11 Function
257           %5 = OpVariable %11 Function
258           %6 = OpVariable %11 Function
259           %7 = OpVariable %11 Function
260                OpStore %4 %12
261                OpBranch %23
262          %23 = OpLabel
263                OpLoopMerge %24 %25 None
264                OpBranch %26
265          %26 = OpLabel
266          %27 = OpLoad %10 %4
267          %28 = OpSLessThan %14 %27 %13
268                OpBranchConditional %28 %29 %24
269          %29 = OpLabel
270                OpStore %5 %12
271                OpBranch %30
272          %30 = OpLabel
273                OpLoopMerge %31 %32 None
274                OpBranch %33
275          %33 = OpLabel
276          %34 = OpLoad %10 %5
277          %35 = OpSLessThan %14 %34 %15
278                OpBranchConditional %35 %36 %31
279          %36 = OpLabel
280          %37 = OpLoad %10 %5
281          %38 = OpSLessThan %14 %37 %16
282                OpSelectionMerge %39 None
283                OpBranchConditional %38 %40 %39
284          %40 = OpLabel
285                OpStore %6 %12
286                OpBranch %41
287          %41 = OpLabel
288                OpLoopMerge %42 %43 None
289                OpBranch %44
290          %44 = OpLabel
291          %45 = OpLoad %10 %6
292          %46 = OpSLessThan %14 %45 %17
293                OpBranchConditional %46 %47 %42
294          %47 = OpLabel
295                OpBranch %43
296          %43 = OpLabel
297          %48 = OpLoad %10 %6
298          %49 = OpIAdd %10 %48 %18
299                OpStore %6 %49
300                OpBranch %41
301          %42 = OpLabel
302                OpBranch %39
303          %39 = OpLabel
304                OpStore %7 %12
305                OpBranch %50
306          %50 = OpLabel
307                OpLoopMerge %51 %52 None
308                OpBranch %53
309          %53 = OpLabel
310          %54 = OpLoad %10 %7
311          %55 = OpSLessThan %14 %54 %17
312                OpBranchConditional %55 %56 %51
313          %56 = OpLabel
314                OpBranch %52
315          %52 = OpLabel
316          %57 = OpLoad %10 %7
317          %58 = OpIAdd %10 %57 %18
318                OpStore %7 %58
319                OpBranch %50
320          %51 = OpLabel
321                OpBranch %32
322          %32 = OpLabel
323          %59 = OpLoad %10 %5
324          %60 = OpIAdd %10 %59 %18
325                OpStore %5 %60
326                OpBranch %30
327          %31 = OpLabel
328                OpBranch %25
329          %25 = OpLabel
330          %61 = OpLoad %10 %4
331          %62 = OpIAdd %10 %61 %18
332                OpStore %4 %62
333                OpBranch %23
334          %24 = OpLabel
335                OpReturn
336                OpFunctionEnd
337   )";
338   // clang-format on
339   std::unique_ptr<ir::IRContext> context =
340       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
341                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
342   ir::Module* module = context->module();
343   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
344                              << text << std::endl;
345   const ir::Function* f = spvtest::GetFunction(module, 2);
346   ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f);
347
348   EXPECT_EQ(ld.NumLoops(), 4u);
349
350   // Invalid basic block id.
351   EXPECT_EQ(ld[0u], nullptr);
352   // Not in a loop.
353   EXPECT_EQ(ld[22], nullptr);
354
355   // Check that we can map basic block to the correct loop.
356   // The following block ids do not belong to a loop.
357   for (uint32_t bb_id : {22, 24}) EXPECT_EQ(ld[bb_id], nullptr);
358
359   {
360     std::unordered_set<uint32_t> basic_block_in_loop = {
361         {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43,
362          42, 39, 50, 53, 56, 52, 51, 32, 31, 25}};
363     ir::Loop* loop = ld[23];
364     CheckLoopBlocks(loop, &basic_block_in_loop);
365
366     EXPECT_TRUE(loop->HasNestedLoops());
367     EXPECT_FALSE(loop->IsNested());
368     EXPECT_EQ(loop->GetDepth(), 1u);
369     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u);
370     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22));
371     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23));
372     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25));
373     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24));
374     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
375     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
376   }
377
378   {
379     std::unordered_set<uint32_t> basic_block_in_loop = {
380         {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}};
381     ir::Loop* loop = ld[30];
382     CheckLoopBlocks(loop, &basic_block_in_loop);
383
384     EXPECT_TRUE(loop->HasNestedLoops());
385     EXPECT_TRUE(loop->IsNested());
386     EXPECT_EQ(loop->GetDepth(), 2u);
387     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u);
388     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29));
389     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30));
390     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32));
391     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31));
392     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
393     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
394   }
395
396   {
397     std::unordered_set<uint32_t> basic_block_in_loop = {{41, 44, 47, 43}};
398     ir::Loop* loop = ld[41];
399     CheckLoopBlocks(loop, &basic_block_in_loop);
400
401     EXPECT_FALSE(loop->HasNestedLoops());
402     EXPECT_TRUE(loop->IsNested());
403     EXPECT_EQ(loop->GetDepth(), 3u);
404     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
405     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40));
406     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41));
407     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43));
408     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42));
409     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
410     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
411   }
412
413   {
414     std::unordered_set<uint32_t> basic_block_in_loop = {{50, 53, 56, 52}};
415     ir::Loop* loop = ld[50];
416     CheckLoopBlocks(loop, &basic_block_in_loop);
417
418     EXPECT_FALSE(loop->HasNestedLoops());
419     EXPECT_TRUE(loop->IsNested());
420     EXPECT_EQ(loop->GetDepth(), 3u);
421     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
422     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39));
423     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50));
424     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52));
425     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51));
426     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
427     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
428   }
429
430   // Make sure LoopDescriptor gives us the inner most loop when we query for
431   // loops.
432   for (const ir::BasicBlock& bb : *f) {
433     if (ir::Loop* loop = ld[&bb]) {
434       for (ir::Loop& sub_loop :
435            ir::make_range(++opt::TreeDFIterator<ir::Loop>(loop),
436                           opt::TreeDFIterator<ir::Loop>())) {
437         EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id()));
438       }
439     }
440   }
441 }
442
443 /*
444 Generated from the following GLSL
445 #version 330 core
446 layout(location = 0) out vec4 c;
447 void main() {
448   for (int i = 0; i < 10; ++i) {
449     for (int j = 0; j < 11; ++j) {
450       for (int k = 0; k < 11; ++k) {}
451     }
452     for (int k = 0; k < 12; ++k) {}
453   }
454 }
455 */
456 TEST_F(PassClassTest, LoopParentTest) {
457   const std::string text = R"(
458                OpCapability Shader
459           %1 = OpExtInstImport "GLSL.std.450"
460                OpMemoryModel Logical GLSL450
461                OpEntryPoint Fragment %2 "main" %3
462                OpExecutionMode %2 OriginUpperLeft
463                OpSource GLSL 330
464                OpName %2 "main"
465                OpName %4 "i"
466                OpName %5 "j"
467                OpName %6 "k"
468                OpName %7 "k"
469                OpName %3 "c"
470                OpDecorate %3 Location 0
471           %8 = OpTypeVoid
472           %9 = OpTypeFunction %8
473          %10 = OpTypeInt 32 1
474          %11 = OpTypePointer Function %10
475          %12 = OpConstant %10 0
476          %13 = OpConstant %10 10
477          %14 = OpTypeBool
478          %15 = OpConstant %10 11
479          %16 = OpConstant %10 1
480          %17 = OpConstant %10 12
481          %18 = OpTypeFloat 32
482          %19 = OpTypeVector %18 4
483          %20 = OpTypePointer Output %19
484           %3 = OpVariable %20 Output
485           %2 = OpFunction %8 None %9
486          %21 = OpLabel
487           %4 = OpVariable %11 Function
488           %5 = OpVariable %11 Function
489           %6 = OpVariable %11 Function
490           %7 = OpVariable %11 Function
491                OpStore %4 %12
492                OpBranch %22
493          %22 = OpLabel
494                OpLoopMerge %23 %24 None
495                OpBranch %25
496          %25 = OpLabel
497          %26 = OpLoad %10 %4
498          %27 = OpSLessThan %14 %26 %13
499                OpBranchConditional %27 %28 %23
500          %28 = OpLabel
501                OpStore %5 %12
502                OpBranch %29
503          %29 = OpLabel
504                OpLoopMerge %30 %31 None
505                OpBranch %32
506          %32 = OpLabel
507          %33 = OpLoad %10 %5
508          %34 = OpSLessThan %14 %33 %15
509                OpBranchConditional %34 %35 %30
510          %35 = OpLabel
511                OpStore %6 %12
512                OpBranch %36
513          %36 = OpLabel
514                OpLoopMerge %37 %38 None
515                OpBranch %39
516          %39 = OpLabel
517          %40 = OpLoad %10 %6
518          %41 = OpSLessThan %14 %40 %15
519                OpBranchConditional %41 %42 %37
520          %42 = OpLabel
521                OpBranch %38
522          %38 = OpLabel
523          %43 = OpLoad %10 %6
524          %44 = OpIAdd %10 %43 %16
525                OpStore %6 %44
526                OpBranch %36
527          %37 = OpLabel
528                OpBranch %31
529          %31 = OpLabel
530          %45 = OpLoad %10 %5
531          %46 = OpIAdd %10 %45 %16
532                OpStore %5 %46
533                OpBranch %29
534          %30 = OpLabel
535                OpStore %7 %12
536                OpBranch %47
537          %47 = OpLabel
538                OpLoopMerge %48 %49 None
539                OpBranch %50
540          %50 = OpLabel
541          %51 = OpLoad %10 %7
542          %52 = OpSLessThan %14 %51 %17
543                OpBranchConditional %52 %53 %48
544          %53 = OpLabel
545                OpBranch %49
546          %49 = OpLabel
547          %54 = OpLoad %10 %7
548          %55 = OpIAdd %10 %54 %16
549                OpStore %7 %55
550                OpBranch %47
551          %48 = OpLabel
552                OpBranch %24
553          %24 = OpLabel
554          %56 = OpLoad %10 %4
555          %57 = OpIAdd %10 %56 %16
556                OpStore %4 %57
557                OpBranch %22
558          %23 = OpLabel
559                OpReturn
560                OpFunctionEnd
561   )";
562   // clang-format on
563   std::unique_ptr<ir::IRContext> context =
564       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
565                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
566   ir::Module* module = context->module();
567   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
568                              << text << std::endl;
569   const ir::Function* f = spvtest::GetFunction(module, 2);
570   ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f);
571
572   EXPECT_EQ(ld.NumLoops(), 4u);
573
574   {
575     ir::Loop& loop = *ld[22];
576     EXPECT_TRUE(loop.HasNestedLoops());
577     EXPECT_FALSE(loop.IsNested());
578     EXPECT_EQ(loop.GetDepth(), 1u);
579     EXPECT_EQ(loop.GetParent(), nullptr);
580   }
581
582   {
583     ir::Loop& loop = *ld[29];
584     EXPECT_TRUE(loop.HasNestedLoops());
585     EXPECT_TRUE(loop.IsNested());
586     EXPECT_EQ(loop.GetDepth(), 2u);
587     EXPECT_EQ(loop.GetParent(), ld[22]);
588   }
589
590   {
591     ir::Loop& loop = *ld[36];
592     EXPECT_FALSE(loop.HasNestedLoops());
593     EXPECT_TRUE(loop.IsNested());
594     EXPECT_EQ(loop.GetDepth(), 3u);
595     EXPECT_EQ(loop.GetParent(), ld[29]);
596   }
597
598   {
599     ir::Loop& loop = *ld[47];
600     EXPECT_FALSE(loop.HasNestedLoops());
601     EXPECT_TRUE(loop.IsNested());
602     EXPECT_EQ(loop.GetDepth(), 2u);
603     EXPECT_EQ(loop.GetParent(), ld[22]);
604   }
605 }
606
607 /*
608 Generated from the following GLSL + --eliminate-local-multi-store
609 The preheader of loop %33 and %41 were removed as well.
610
611 #version 330 core
612 void main() {
613   int a = 0;
614   for (int i = 0; i < 10; ++i) {
615     if (i == 0) {
616       a = 1;
617     } else {
618       a = 2;
619     }
620     for (int j = 0; j < 11; ++j) {
621       a++;
622     }
623   }
624   for (int k = 0; k < 12; ++k) {}
625 }
626 */
627 TEST_F(PassClassTest, CreatePreheaderTest) {
628   const std::string text = R"(
629                OpCapability Shader
630           %1 = OpExtInstImport "GLSL.std.450"
631                OpMemoryModel Logical GLSL450
632                OpEntryPoint Fragment %2 "main"
633                OpExecutionMode %2 OriginUpperLeft
634                OpSource GLSL 330
635                OpName %2 "main"
636           %3 = OpTypeVoid
637           %4 = OpTypeFunction %3
638           %5 = OpTypeInt 32 1
639           %6 = OpTypePointer Function %5
640           %7 = OpConstant %5 0
641           %8 = OpConstant %5 10
642           %9 = OpTypeBool
643          %10 = OpConstant %5 1
644          %11 = OpConstant %5 2
645          %12 = OpConstant %5 11
646          %13 = OpConstant %5 12
647          %14 = OpUndef %5
648           %2 = OpFunction %3 None %4
649          %15 = OpLabel
650                OpBranch %16
651          %16 = OpLabel
652          %17 = OpPhi %5 %7 %15 %18 %19
653          %20 = OpPhi %5 %7 %15 %21 %19
654          %22 = OpPhi %5 %14 %15 %23 %19
655                OpLoopMerge %41 %19 None
656                OpBranch %25
657          %25 = OpLabel
658          %26 = OpSLessThan %9 %20 %8
659                OpBranchConditional %26 %27 %41
660          %27 = OpLabel
661          %28 = OpIEqual %9 %20 %7
662                OpSelectionMerge %33 None
663                OpBranchConditional %28 %30 %31
664          %30 = OpLabel
665                OpBranch %33
666          %31 = OpLabel
667                OpBranch %33
668          %33 = OpLabel
669          %18 = OpPhi %5 %10 %30 %11 %31 %34 %35
670          %23 = OpPhi %5 %7 %30 %7 %31 %36 %35
671                OpLoopMerge %37 %35 None
672                OpBranch %38
673          %38 = OpLabel
674          %39 = OpSLessThan %9 %23 %12
675                OpBranchConditional %39 %40 %37
676          %40 = OpLabel
677          %34 = OpIAdd %5 %18 %10
678                OpBranch %35
679          %35 = OpLabel
680          %36 = OpIAdd %5 %23 %10
681                OpBranch %33
682          %37 = OpLabel
683                OpBranch %19
684          %19 = OpLabel
685          %21 = OpIAdd %5 %20 %10
686                OpBranch %16
687          %41 = OpLabel
688          %42 = OpPhi %5 %7 %25 %43 %44
689                OpLoopMerge %45 %44 None
690                OpBranch %46
691          %46 = OpLabel
692          %47 = OpSLessThan %9 %42 %13
693                OpBranchConditional %47 %48 %45
694          %48 = OpLabel
695                OpBranch %44
696          %44 = OpLabel
697          %43 = OpIAdd %5 %42 %10
698                OpBranch %41
699          %45 = OpLabel
700                OpReturn
701                OpFunctionEnd
702   )";
703   // clang-format on
704   std::unique_ptr<ir::IRContext> context =
705       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
706                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
707   ir::Module* module = context->module();
708   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
709                              << text << std::endl;
710   const ir::Function* f = spvtest::GetFunction(module, 2);
711   ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f);
712   // No invalidation of the cfg should occur during this test.
713   ir::CFG* cfg = context->cfg();
714
715   EXPECT_EQ(ld.NumLoops(), 3u);
716
717   {
718     ir::Loop& loop = *ld[16];
719     EXPECT_TRUE(loop.HasNestedLoops());
720     EXPECT_FALSE(loop.IsNested());
721     EXPECT_EQ(loop.GetDepth(), 1u);
722     EXPECT_EQ(loop.GetParent(), nullptr);
723   }
724
725   {
726     ir::Loop& loop = *ld[33];
727     EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
728     EXPECT_NE(loop.GetOrCreatePreHeaderBlock(context.get()), nullptr);
729     // Make sure the loop descriptor was properly updated.
730     EXPECT_EQ(ld[loop.GetPreHeaderBlock()], ld[16]);
731     {
732       const std::vector<uint32_t>& preds =
733           cfg->preds(loop.GetPreHeaderBlock()->id());
734       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
735       EXPECT_EQ(pred_set.size(), 2u);
736       EXPECT_TRUE(pred_set.count(30));
737       EXPECT_TRUE(pred_set.count(31));
738       // Check the phi instructions.
739       loop.GetPreHeaderBlock()->ForEachPhiInst(
740           [&pred_set](ir::Instruction* phi) {
741             for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
742               EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
743             }
744           });
745     }
746     {
747       const std::vector<uint32_t>& preds =
748           cfg->preds(loop.GetHeaderBlock()->id());
749       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
750       EXPECT_EQ(pred_set.size(), 2u);
751       EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
752       EXPECT_TRUE(pred_set.count(35));
753       // Check the phi instructions.
754       loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](ir::Instruction* phi) {
755         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
756           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
757         }
758       });
759     }
760   }
761
762   {
763     ir::Loop& loop = *ld[41];
764     EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
765     EXPECT_NE(loop.GetOrCreatePreHeaderBlock(context.get()), nullptr);
766     EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr);
767     EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id()).size(), 1u);
768     EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id())[0], 25u);
769     // Check the phi instructions.
770     loop.GetPreHeaderBlock()->ForEachPhiInst([](ir::Instruction* phi) {
771       EXPECT_EQ(phi->NumInOperands(), 2u);
772       EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u);
773     });
774     {
775       const std::vector<uint32_t>& preds =
776           cfg->preds(loop.GetHeaderBlock()->id());
777       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
778       EXPECT_EQ(pred_set.size(), 2u);
779       EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
780       EXPECT_TRUE(pred_set.count(44));
781       // Check the phi instructions.
782       loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](ir::Instruction* phi) {
783         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
784           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
785         }
786       });
787     }
788   }
789
790   // Make sure pre-header insertion leaves the module valid.
791   std::vector<uint32_t> bin;
792   context->module()->ToBinary(&bin, true);
793   EXPECT_TRUE(Validate(bin));
794 }
795
796 }  // namespace