1 // Copyright (c) 2017 Google Inc.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
21 #include "const_folding_rules.h"
22 #include "def_use_manager.h"
23 #include "folding_rules.h"
24 #include "ir_builder.h"
25 #include "ir_context.h"
// Local definitions of the 32-bit integer limits that the comparison folding
// checks in this file compare constants against.
// NOTE(review): the literal 2147483648 does not fit in a 32-bit int, so this
// expression takes a wider integer type than int32_t; <cstdint>
// implementations usually spell the value as (-2147483647 - 1) — confirm
// whether the wider type matters for the comparisons below.
33 #define INT32_MIN (-2147483648)
37 #define INT32_MAX 2147483647
// The unsuffixed hex literal is unsigned only because it does not fit in int.
41 #define UINT32_MAX 0xffffffff /* 4294967295U */
44 // Returns the single-word result from performing the given unary operation on
45 // the operand value which is passed in as a 32-bit word.
46 uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
49 case SpvOp::SpvOpSNegate:
// The word is reinterpreted as a signed value and negated; the return type
// converts the result back to uint32_t.
// NOTE(review): negating the most negative int32_t overflows in C++ —
// confirm callers never pass 0x80000000 here.
50 return -static_cast<int32_t>(operand);
53 case SpvOp::SpvOpLogicalNot:
// Any non-zero word is treated as true, so this yields 1 for 0 and 0 for
// every other input.
54 return !static_cast<bool>(operand);
// Diagnostic text for opcodes that are not handled above.
57 "Unsupported unary operation for OpSpecConstantOp instruction");
62 // Returns the single-word result from performing the given binary operation on
63 // the operand values which are passed in as two 32-bit words.
64 uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
// |a| is the first operand word and |b| the second. Signed opcodes
// reinterpret the words as int32_t before operating; the return type
// converts every result back to uint32_t (comparison and logical results
// therefore come back as 0 or 1).
67 case SpvOp::SpvOpIAdd:
69 case SpvOp::SpvOpISub:
71 case SpvOp::SpvOpIMul:
73 case SpvOp::SpvOpUDiv:
76 case SpvOp::SpvOpSDiv:
// NOTE(review): this expression is undefined in C++ when b is 0 or when a
// is the most negative int32_t and b is -1 — confirm callers exclude those
// operand values.
78 return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
79 case SpvOp::SpvOpSRem: {
80 // The sign of non-zero result comes from the first operand: a. This is
81 // guaranteed by C++11 rules for integer division operator. The division
82 // result is rounded toward zero, so the result of '%' has the sign of
// NOTE(review): same undefined cases as the signed division above (b == 0,
// most negative a with b == -1).
85 return static_cast<int32_t>(a) % static_cast<int32_t>(b);
87 case SpvOp::SpvOpSMod: {
88 // The sign of non-zero result comes from the second operand: b
// Start from the SRem result (which carries a's sign), add one divisor and
// reduce again so a non-zero remainder ends up carrying b's sign.
90 int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
91 int32_t b_prim = static_cast<int32_t>(b);
// NOTE(review): rem + b_prim can exceed the int32_t range when rem and
// b_prim share a sign and b_prim is large in magnitude — confirm.
92 return (rem + b_prim) % b_prim;
94 case SpvOp::SpvOpUMod:
99 case SpvOp::SpvOpShiftRightLogical: {
102 case SpvOp::SpvOpShiftRightArithmetic:
// The value is shifted as a signed integer; the shift count is the raw
// unsigned word.
// NOTE(review): a count of 32 or more is undefined for a 32-bit operand in
// C++ — confirm callers bound |b|.
103 return (static_cast<int32_t>(a)) >> b;
104 case SpvOp::SpvOpShiftLeftLogical:
107 // Bitwise operations
108 case SpvOp::SpvOpBitwiseOr:
110 case SpvOp::SpvOpBitwiseAnd:
112 case SpvOp::SpvOpBitwiseXor:
// Logical opcodes: each word is first collapsed to a bool (non-zero is
// true) and the bools are then combined.
116 case SpvOp::SpvOpLogicalEqual:
117 return (static_cast<bool>(a)) == (static_cast<bool>(b));
118 case SpvOp::SpvOpLogicalNotEqual:
119 return (static_cast<bool>(a)) != (static_cast<bool>(b));
120 case SpvOp::SpvOpLogicalOr:
121 return (static_cast<bool>(a)) || (static_cast<bool>(b));
122 case SpvOp::SpvOpLogicalAnd:
123 return (static_cast<bool>(a)) && (static_cast<bool>(b));
126 case SpvOp::SpvOpIEqual:
128 case SpvOp::SpvOpINotEqual:
130 case SpvOp::SpvOpULessThan:
132 case SpvOp::SpvOpSLessThan:
// Signed comparisons cast both words to int32_t so the ordering is signed.
133 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
134 case SpvOp::SpvOpUGreaterThan:
136 case SpvOp::SpvOpSGreaterThan:
137 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
138 case SpvOp::SpvOpULessThanEqual:
140 case SpvOp::SpvOpSLessThanEqual:
141 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
142 case SpvOp::SpvOpUGreaterThanEqual:
144 case SpvOp::SpvOpSGreaterThanEqual:
145 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
// Diagnostic text for opcodes that are not handled above.
148 "Unsupported binary operation for OpSpecConstantOp instruction");
153 // Returns the single-word result from performing the given ternary operation
154 // on the operand values which are passed in as three 32-bit words.
155 uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
157 case SpvOp::SpvOpSelect:
// |a| is the condition word: non-zero selects |b|, zero selects |c|.
158 return (static_cast<bool>(a)) ? b : c;
// Diagnostic text for opcodes that are not handled above.
161 "Unsupported ternary operation for OpSpecConstantOp instruction");
166 // Returns the single-word result from performing the given operation on the
167 // operand words. This only works with 32-bit operations and uses boolean
168 // convention that 0u is false, and anything else is boolean true.
169 // TODO(qining): Support operands other than 32-bit wide.
170 uint32_t OperateWords(SpvOp opcode,
171 const std::vector<uint32_t>& operand_words) {
// Dispatch on the number of operand words to the unary, binary or ternary
// helper.
172 switch (operand_words.size()) {
174 return UnaryOperate(opcode, operand_words.front());
// With two words, front() is the first operand and back() is the second.
176 return BinaryOperate(opcode, operand_words.front(), operand_words.back());
178 return TernaryOperate(opcode, operand_words[0], operand_words[1],
// Any other operand count is treated as a programming error.
181 assert(false && "Invalid number of operands");
// Attempts one folding step on |inst|. It first tries to fold |inst| to a
// constant; when that yields an instruction, |inst| is rewritten in place as
// an OpCopyObject of that instruction's result id. It also collects the
// declared constants for |inst|'s operands and offers |inst| to each folding
// rule registered for its opcode.
186 bool FoldInstructionInternal(ir::Instruction* inst) {
187 ir::IRContext* context = inst->context();
// Ids are looked up unchanged for this constant-folding attempt.
188 auto identity_map = [](uint32_t id) { return id; };
189 ir::Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
190 if (folded_inst != nullptr) {
// Turn |inst| into a copy of the constant's defining instruction.
191 inst->SetOpcode(SpvOpCopyObject);
192 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
196 SpvOp opcode = inst->opcode();
197 analysis::ConstantManager* const_manger = context->get_constant_mgr();
// Build one entry per in-operand: nullptr for operands that are not ids,
// otherwise whatever FindDeclaredConstant returns for the id.
199 std::vector<const analysis::Constant*> constants;
200 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
201 const ir::Operand* operand = &inst->GetInOperand(i);
202 if (operand->type != SPV_OPERAND_TYPE_ID) {
203 constants.push_back(nullptr);
205 uint32_t id = operand->words[0];
206 const analysis::Constant* constant =
207 const_manger->FindDeclaredConstant(id);
208 constants.push_back(constant);
// The rule table is allocated on the first call and reused by later calls.
212 static FoldingRules* rules = new FoldingRules();
213 for (FoldingRule rule : rules->GetRulesForOpcode(opcode)) {
214 if (rule(inst, constants)) {
// Accessor for the constant folding rule table. The table is allocated with
// new the first time this function runs and kept in a function-local static.
223 const ConstantFoldingRules& GetConstantFoldingRules() {
224 static ConstantFoldingRules* rules = new ConstantFoldingRules();
228 // Returns the result of performing an operation on scalar constant operands.
229 // This function extracts the operand values as 32 bit words and returns the
230 // result in 32 bit word. Scalar constants with longer than 32-bit width are
231 // not accepted in this function.
232 uint32_t FoldScalars(SpvOp opcode,
233 const std::vector<const analysis::Constant*>& operands) {
234 assert(IsFoldableOpcode(opcode) &&
235 "Unhandled instruction opcode in FoldScalars");
// Flatten every operand to a single 32-bit word, preserving operand order.
236 std::vector<uint32_t> operand_values_in_raw_words;
237 for (const auto& operand : operands) {
// |operand| is dereferenced unconditionally, so every entry must be
// non-null.
238 if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
239 const auto& scalar_words = scalar->words();
240 assert(scalar_words.size() == 1 &&
241 "Scalar constants with longer than 32-bit width are not allowed "
243 operand_values_in_raw_words.push_back(scalar_words.front());
244 } else if (operand->AsNullConstant()) {
// A null constant contributes the word 0.
245 operand_values_in_raw_words.push_back(0u);
248 "FoldScalars() only accepts ScalarConst or NullConst type of "
// The arithmetic itself is done on the raw words.
252 return OperateWords(opcode, operand_values_in_raw_words);
255 // Returns true if |inst| is a binary operation that takes two integers as
256 // parameters and folds to a constant that can be represented as an unsigned
257 // 32-bit value when the ids have been replaced by |id_map|. If |inst| can be
258 // folded, the resulting value is returned in |*result|. Valid result types for
259 // the instruction are any integer (signed or unsigned) with 32-bits or less, or
261 bool FoldBinaryIntegerOpToConstant(ir::Instruction* inst,
262 std::function<uint32_t(uint32_t)> id_map,
264 SpvOp opcode = inst->opcode();
265 ir::IRContext* context = inst->context();
266 analysis::ConstantManager* const_manger = context->get_constant_mgr();
// Look up both in-operands through |id_map|. An entry stays nullptr when the
// mapped id has no declared constant or the constant is not an integer
// constant.
269 const analysis::IntConstant* constants[2];
270 for (uint32_t i = 0; i < 2; i++) {
271 const ir::Operand* operand = &inst->GetInOperand(i);
272 if (operand->type != SPV_OPERAND_TYPE_ID) {
275 ids[i] = id_map(operand->words[0]);
276 const analysis::Constant* constant =
277 const_manger->FindDeclaredConstant(ids[i]);
278 constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
// Each case below looks for a single known operand that already fixes the
// outcome regardless of the other operand.
283 case SpvOp::SpvOpIMul:
// A zero on either side of a multiplication.
284 for (uint32_t i = 0; i < 2; i++) {
285 if (constants[i] != nullptr && constants[i]->IsZero()) {
291 case SpvOp::SpvOpUDiv:
292 case SpvOp::SpvOpSDiv:
293 case SpvOp::SpvOpSRem:
294 case SpvOp::SpvOpSMod:
295 case SpvOp::SpvOpUMod:
296 // This changes undefined behaviour (ie divide by 0) into a 0.
// Both positions are checked, so a zero dividend and a zero divisor are
// treated alike.
297 for (uint32_t i = 0; i < 2; i++) {
298 if (constants[i] != nullptr && constants[i]->IsZero()) {
306 case SpvOp::SpvOpShiftRightLogical:
307 case SpvOp::SpvOpShiftLeftLogical:
// Only the shift amount (second operand) is inspected here.
308 if (constants[1] != nullptr) {
309 // When shifting by a value larger than the size of the result, the
310 // result is undefined. We are setting the undefined behaviour to a
312 uint32_t shift_amount = constants[1]->GetU32BitValue();
313 if (shift_amount >= 32) {
320 // Bitwise operations
321 case SpvOp::SpvOpBitwiseOr:
322 for (uint32_t i = 0; i < 2; i++) {
323 if (constants[i] != nullptr) {
324 // TODO: Change the mask against a value based on the bit width of the
325 // instruction result type. This way we can handle say 16-bit values
// OR with an all-ones operand is all ones whatever the other operand is.
327 uint32_t mask = constants[i]->GetU32BitValue();
328 if (mask == 0xFFFFFFFF) {
329 *result = 0xFFFFFFFF;
335 case SpvOp::SpvOpBitwiseAnd:
// A zero on either side of an AND.
336 for (uint32_t i = 0; i < 2; i++) {
337 if (constants[i] != nullptr) {
338 if (constants[i]->IsZero()) {
// Comparisons: a known operand sitting at the extreme of its range decides
// the comparison without knowing the other side.
347 case SpvOp::SpvOpULessThan:
// UINT32_MAX is not less than any unsigned value; nothing is less than 0.
348 if (constants[0] != nullptr &&
349 constants[0]->GetU32BitValue() == UINT32_MAX) {
353 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
358 case SpvOp::SpvOpSLessThan:
// INT32_MAX is not less than any signed value; nothing is less than
// INT32_MIN.
359 if (constants[0] != nullptr &&
360 constants[0]->GetS32BitValue() == INT32_MAX) {
364 if (constants[1] != nullptr &&
365 constants[1]->GetS32BitValue() == INT32_MIN) {
370 case SpvOp::SpvOpUGreaterThan:
// 0 is not greater than any unsigned value; nothing is greater than
// UINT32_MAX.
371 if (constants[0] != nullptr && constants[0]->IsZero()) {
375 if (constants[1] != nullptr &&
376 constants[1]->GetU32BitValue() == UINT32_MAX) {
381 case SpvOp::SpvOpSGreaterThan:
// INT32_MIN is not greater than any signed value; nothing is greater than
// INT32_MAX.
382 if (constants[0] != nullptr &&
383 constants[0]->GetS32BitValue() == INT32_MIN) {
387 if (constants[1] != nullptr &&
388 constants[1]->GetS32BitValue() == INT32_MAX) {
393 case SpvOp::SpvOpULessThanEqual:
// 0 is <= every unsigned value; every unsigned value is <= UINT32_MAX.
394 if (constants[0] != nullptr && constants[0]->IsZero()) {
398 if (constants[1] != nullptr &&
399 constants[1]->GetU32BitValue() == UINT32_MAX) {
404 case SpvOp::SpvOpSLessThanEqual:
// INT32_MIN is <= every signed value; every signed value is <= INT32_MAX.
405 if (constants[0] != nullptr &&
406 constants[0]->GetS32BitValue() == INT32_MIN) {
410 if (constants[1] != nullptr &&
411 constants[1]->GetS32BitValue() == INT32_MAX) {
416 case SpvOp::SpvOpUGreaterThanEqual:
// UINT32_MAX is >= every unsigned value; every unsigned value is >= 0.
417 if (constants[0] != nullptr &&
418 constants[0]->GetU32BitValue() == UINT32_MAX) {
422 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
427 case SpvOp::SpvOpSGreaterThanEqual:
// INT32_MAX is >= every signed value; every signed value is >= INT32_MIN.
428 if (constants[0] != nullptr &&
429 constants[0]->GetS32BitValue() == INT32_MAX) {
433 if (constants[1] != nullptr &&
434 constants[1]->GetS32BitValue() == INT32_MIN) {
445 // Returns true if |inst| is a binary operation on two boolean values, and folds
446 // to a constant boolean value when the ids have been replaced using |id_map|.
447 // If |inst| can be folded, the result value is returned in |*result|.
448 bool FoldBinaryBooleanOpToConstant(ir::Instruction* inst,
449 std::function<uint32_t(uint32_t)> id_map,
451 SpvOp opcode = inst->opcode();
452 ir::IRContext* context = inst->context();
453 analysis::ConstantManager* const_manger = context->get_constant_mgr();
// Look up both in-operands through |id_map|. An entry stays nullptr when the
// mapped id has no declared constant or the constant is not a boolean
// constant.
456 const analysis::BoolConstant* constants[2];
457 for (uint32_t i = 0; i < 2; i++) {
458 const ir::Operand* operand = &inst->GetInOperand(i);
459 if (operand->type != SPV_OPERAND_TYPE_ID) {
462 ids[i] = id_map(operand->words[0]);
463 const analysis::Constant* constant =
464 const_manger->FindDeclaredConstant(ids[i]);
465 constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
470 case SpvOp::SpvOpLogicalOr:
// One known-true operand decides an OR regardless of the other operand.
471 for (uint32_t i = 0; i < 2; i++) {
472 if (constants[i] != nullptr) {
473 if (constants[i]->value()) {
480 case SpvOp::SpvOpLogicalAnd:
// One known-false operand decides an AND regardless of the other operand.
481 for (uint32_t i = 0; i < 2; i++) {
482 if (constants[i] != nullptr) {
483 if (!constants[i]->value()) {
497 // Returns true if |inst| can be folded to a constant when the ids have been
498 // substituted using id_map. If it can, the value is returned in |result|. If
499 // not, |result| is unchanged. It is assumed that not all operands are
500 // constant. Those cases are handled by |FoldScalars|.
501 bool FoldIntegerOpToConstant(ir::Instruction* inst,
502 std::function<uint32_t(uint32_t)> id_map,
// NOTE(review): the assertion text below names FoldScalars although it
// guards this function — confirm whether that is intentional.
504 assert(IsFoldableOpcode(inst->opcode()) &&
505 "Unhandled instruction opcode in FoldScalars");
// Dispatch on the number of in-operands.
506 switch (inst->NumInOperands()) {
// The integer folder is tried first; the boolean folder only runs when the
// integer folder returns false.
508 return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
509 FoldBinaryBooleanOpToConstant(inst, id_map, result);
// Folds |opcode| component-wise over |operands| and returns one result word
// per dimension, |num_dims| words in total. Each operand is expected to be a
// vector constant or a null constant; a null constant contributes the word 0
// in every dimension.
515 std::vector<uint32_t> FoldVectors(
516 SpvOp opcode, uint32_t num_dims,
517 const std::vector<const analysis::Constant*>& operands) {
518 assert(IsFoldableOpcode(opcode) &&
519 "Unhandled instruction opcode in FoldVectors");
520 std::vector<uint32_t> result;
521 for (uint32_t d = 0; d < num_dims; d++) {
// Gather the d-th component word of every operand, in operand order.
522 std::vector<uint32_t> operand_values_for_one_dimension;
523 for (const auto& operand : operands) {
524 if (const analysis::VectorConstant* vector_operand =
525 operand->AsVectorConstant()) {
526 // Extract the raw value of the scalar component constants
527 // in 32-bit words here. The reason of not using FoldScalars() here
528 // is that we do not create temporary null constants as components
529 // when the vector operand is a NullConstant because Constant creation
530 // may need extra checks for the validity and that is not managed in
// .at(d) is bounds-checked, so |num_dims| must not exceed the component
// count of any vector operand.
532 if (const analysis::ScalarConstant* scalar_component =
533 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
534 const auto& scalar_words = scalar_component->words();
536 scalar_words.size() == 1 &&
537 "Vector components with longer than 32-bit width are not allowed "
539 operand_values_for_one_dimension.push_back(scalar_words.front());
// NOTE(review): this tests the whole |operand|, which has already matched
// AsVectorConstant() above, rather than the component at index d — confirm
// whether the component was meant to be tested for a null constant.
540 } else if (operand->AsNullConstant()) {
541 operand_values_for_one_dimension.push_back(0u);
544 "VectorConst should only has ScalarConst or NullConst as "
547 } else if (operand->AsNullConstant()) {
548 operand_values_for_one_dimension.push_back(0u);
551 "FoldVectors() only accepts VectorConst or NullConst type of "
// Fold this dimension's words and append the resulting word.
555 result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
// Classifies |opcode| for the folding helpers in this file, which assert on
// this predicate before operating. The case labels below enumerate the
// bitwise, integer arithmetic, logical, shift, select and integer comparison
// opcodes that are recognised.
560 bool IsFoldableOpcode(SpvOp opcode) {
561 // NOTE: Extend to more opcodes as new cases are handled in the folder
564 case SpvOp::SpvOpBitwiseAnd:
565 case SpvOp::SpvOpBitwiseOr:
566 case SpvOp::SpvOpBitwiseXor:
567 case SpvOp::SpvOpIAdd:
568 case SpvOp::SpvOpIEqual:
569 case SpvOp::SpvOpIMul:
570 case SpvOp::SpvOpINotEqual:
571 case SpvOp::SpvOpISub:
572 case SpvOp::SpvOpLogicalAnd:
573 case SpvOp::SpvOpLogicalEqual:
574 case SpvOp::SpvOpLogicalNot:
575 case SpvOp::SpvOpLogicalNotEqual:
576 case SpvOp::SpvOpLogicalOr:
577 case SpvOp::SpvOpNot:
578 case SpvOp::SpvOpSDiv:
579 case SpvOp::SpvOpSelect:
580 case SpvOp::SpvOpSGreaterThan:
581 case SpvOp::SpvOpSGreaterThanEqual:
582 case SpvOp::SpvOpShiftLeftLogical:
583 case SpvOp::SpvOpShiftRightArithmetic:
584 case SpvOp::SpvOpShiftRightLogical:
585 case SpvOp::SpvOpSLessThan:
586 case SpvOp::SpvOpSLessThanEqual:
587 case SpvOp::SpvOpSMod:
588 case SpvOp::SpvOpSNegate:
589 case SpvOp::SpvOpSRem:
590 case SpvOp::SpvOpUDiv:
591 case SpvOp::SpvOpUGreaterThan:
592 case SpvOp::SpvOpUGreaterThanEqual:
593 case SpvOp::SpvOpULessThan:
594 case SpvOp::SpvOpULessThanEqual:
595 case SpvOp::SpvOpUMod:
// Returns whether |cst| is a constant the folder can consume: a scalar
// constant stored in exactly one 32-bit word, or a null constant. |cst| is
// dereferenced unconditionally and must not be null.
602 bool IsFoldableConstant(const analysis::Constant* cst) {
603 // Currently supported constants are 32-bit values or null constants.
604 if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
// A scalar qualifies only when it occupies a single word.
605 return scalar->words().size() == 1;
607 return cst->AsNullConstant() != nullptr;
// Tries to fold |inst| to a constant after its input ids have been
// translated through |id_map|, and returns the instruction that defines the
// resulting constant. Constant folding rules registered for the opcode are
// consulted first; the scalar folding path is used otherwise.
610 ir::Instruction* FoldInstructionToConstant(
611 ir::Instruction* inst, std::function<uint32_t(uint32_t)> id_map) {
612 ir::IRContext* context = inst->context();
613 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
// Guard for instructions that neither the scalar folder nor any constant
// folding rule covers.
615 if (!inst->IsFoldableByFoldScalar() &&
616 !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
619 // Collect the values of the constant parameters.
620 std::vector<const analysis::Constant*> constants;
621 bool missing_constants = false;
// One entry is pushed per input id: nullptr (and |missing_constants| set)
// or the declared constant found for the mapped id.
622 inst->ForEachInId([&constants, &missing_constants, const_mgr,
623 &id_map](uint32_t* op_id) {
624 uint32_t id = id_map(*op_id);
625 const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
627 constants.push_back(nullptr);
628 missing_constants = true;
630 constants.push_back(const_op);
// Rule-based folding: each rule receives the instruction and the collected
// constants and signals success with a non-null constant.
634 if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
635 const analysis::Constant* folded_const = nullptr;
637 GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
638 folded_const = rule(inst, constants);
639 if (folded_const != nullptr) {
640 ir::Instruction* const_inst =
641 const_mgr->GetDefiningInstruction(folded_const);
642 // May be a new instruction that needs to be analysed.
643 context->UpdateDefUse(const_inst);
// Scalar folding path: produces a single 32-bit result word.
649 uint32_t result_val = 0;
650 bool successful = false;
651 // If all parameters are constant, fold the instruction to a constant.
652 if (!missing_constants && inst->IsFoldableByFoldScalar()) {
653 result_val = FoldScalars(inst->opcode(), constants);
// Otherwise try folding from a partial set of constant operands.
657 if (!successful && inst->IsFoldableByFoldScalar()) {
658 successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
// Build a constant of |inst|'s type from the result word and return the
// instruction that defines it.
662 const analysis::Constant* result_const =
663 const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
664 return const_mgr->GetDefiningInstruction(result_const);
// Tests whether the type declared by |type_inst| is one the folder supports.
// The visible checks cover integer types and the boolean type.
669 bool IsFoldableType(ir::Instruction* type_inst) {
670 // Support 32-bit integers.
671 if (type_inst->opcode() == SpvOpTypeInt) {
// In-operand 0 of OpTypeInt is compared against 32, i.e. the bit width.
672 return type_inst->GetSingleWordInOperand(0) == 32;
675 if (type_inst->opcode() == SpvOpTypeBool) {
// Repeatedly applies FoldInstructionInternal to |inst|. The loop ends once
// the instruction has become an OpCopyObject or a folding step reports
// failure.
682 bool FoldInstruction(ir::Instruction* inst) {
683 bool modified = false;
684 ir::Instruction* folded_inst(inst);
// The opcode is tested first, so an instruction that is already an
// OpCopyObject is never passed to FoldInstructionInternal.
685 while (folded_inst->opcode() != SpvOpCopyObject &&
686 FoldInstructionInternal(&*folded_inst)) {
693 } // namespace spvtools