1 // Copyright (c) 2017 Google Inc.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
21 #include "const_folding_rules.h"
22 #include "def_use_manager.h"
23 #include "folding_rules.h"
24 #include "ir_builder.h"
25 #include "ir_context.h"
// Fallback definitions of the 32-bit integer limits that the comparison
// folding rules below test operands against.
// NOTE(review): the unsuffixed literal 2147483648 does not fit in int, so
// INT32_MIN as written has a wider integer type (long or long long); equality
// tests against int32_t values still work through the usual arithmetic
// conversions. Presumably these are guarded against redefinition in the full
// file (e.g. #ifndef), since <cstdint> also defines them — confirm.
33 #define INT32_MIN (-2147483648)
37 #define INT32_MAX 2147483647
41 #define UINT32_MAX 0xffffffff /* 4294967295U */
44 // Returns the single-word result from performing the given unary operation on
45 // the operand value which is passed in as a 32-bit word.
// Boolean operands follow the convention that 0 is false and any other value
// is true; boolean results are returned as 0 or 1.
46 uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
// SNegate reinterprets the word as a signed 32-bit integer.
// NOTE(review): negating INT32_MIN overflows int32_t, which is undefined
// behavior in C++; computing 0u - operand in unsigned arithmetic yields the
// same two's-complement word without UB.
49 case SpvOp::SpvOpSNegate:
50 return -static_cast<int32_t>(operand);
53 case SpvOp::SpvOpLogicalNot:
54 return !static_cast<bool>(operand);
57 "Unsupported unary operation for OpSpecConstantOp instruction");
62 // Returns the single-word result from performing the given binary operation on
63 // the operand values which are passed in as two 32-bit words.
// Signed (S*) operations reinterpret both words as int32_t; logical
// operations treat 0 as false and any other value as true, returning 0 or 1.
64 uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
67 case SpvOp::SpvOpIAdd:
69 case SpvOp::SpvOpISub:
71 case SpvOp::SpvOpIMul:
73 case SpvOp::SpvOpUDiv:
// NOTE(review): INT32_MIN / -1 is not representable in int32_t, so this
// division is undefined behavior in C++ for that operand pair.
76 case SpvOp::SpvOpSDiv:
78 return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
// NOTE(review): C++ also leaves INT32_MIN % -1 undefined, because the
// corresponding quotient is not representable.
79 case SpvOp::SpvOpSRem: {
80 // The sign of non-zero result comes from the first operand: a. This is
81 // guaranteed by C++11 rules for integer division operator. The division
82 // result is rounded toward zero, so the result of '%' has the sign of
85 return static_cast<int32_t>(a) % static_cast<int32_t>(b);
87 case SpvOp::SpvOpSMod: {
88 // The sign of non-zero result comes from the second operand: b
90 int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
91 int32_t b_prim = static_cast<int32_t>(b);
// NOTE(review): rem + b_prim overflows int32_t (undefined behavior) when
// rem and b_prim share a sign and are large, e.g. a = 2147483646,
// b = 2147483647. Adding b_prim only when rem is non-zero and its sign
// differs from b_prim's gives the same result without overflow.
92 return (rem + b_prim) % b_prim;
94 case SpvOp::SpvOpUMod:
99 case SpvOp::SpvOpShiftRightLogical: {
// NOTE(review): shifting by b >= 32 is undefined behavior in C++, and
// right-shifting a negative int32_t is implementation-defined before C++20.
102 case SpvOp::SpvOpShiftRightArithmetic:
103 return (static_cast<int32_t>(a)) >> b;
104 case SpvOp::SpvOpShiftLeftLogical:
107 // Bitwise operations
108 case SpvOp::SpvOpBitwiseOr:
110 case SpvOp::SpvOpBitwiseAnd:
112 case SpvOp::SpvOpBitwiseXor:
// Logical operations: operands are converted to bool before comparing or
// combining them.
116 case SpvOp::SpvOpLogicalEqual:
117 return (static_cast<bool>(a)) == (static_cast<bool>(b));
118 case SpvOp::SpvOpLogicalNotEqual:
119 return (static_cast<bool>(a)) != (static_cast<bool>(b));
120 case SpvOp::SpvOpLogicalOr:
121 return (static_cast<bool>(a)) || (static_cast<bool>(b));
122 case SpvOp::SpvOpLogicalAnd:
123 return (static_cast<bool>(a)) && (static_cast<bool>(b));
// Integer comparisons; the signed variants compare the words as int32_t.
126 case SpvOp::SpvOpIEqual:
128 case SpvOp::SpvOpINotEqual:
130 case SpvOp::SpvOpULessThan:
132 case SpvOp::SpvOpSLessThan:
133 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
134 case SpvOp::SpvOpUGreaterThan:
136 case SpvOp::SpvOpSGreaterThan:
137 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
138 case SpvOp::SpvOpULessThanEqual:
140 case SpvOp::SpvOpSLessThanEqual:
141 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
142 case SpvOp::SpvOpUGreaterThanEqual:
144 case SpvOp::SpvOpSGreaterThanEqual:
145 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
148 "Unsupported binary operation for OpSpecConstantOp instruction");
153 // Returns the single-word result from performing the given ternary operation
154 // on the operand values which are passed in as three 32-bit words.
155 uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
// Select: |a| is the condition (non-zero means true); yields |b| when the
// condition holds and |c| otherwise.
157 case SpvOp::SpvOpSelect:
158 return (static_cast<bool>(a)) ? b : c;
161 "Unsupported ternary operation for OpSpecConstantOp instruction");
166 // Returns the single-word result from performing the given operation on the
167 // operand words. This only works with 32-bit operations and uses the boolean
168 // convention that 0u is false, and anything else is boolean true.
169 // TODO(qining): Support operands other than 32-bit wide.
170 uint32_t OperateWords(SpvOp opcode,
171 const std::vector<uint32_t>& operand_words) {
// Dispatch on the number of operand words: presumably one word goes to
// UnaryOperate, two (front/back) to BinaryOperate and three to
// TernaryOperate; any other count hits the "Invalid number of operands"
// assertion.
172 switch (operand_words.size()) {
174 return UnaryOperate(opcode, operand_words.front());
176 return BinaryOperate(opcode, operand_words.front(), operand_words.back());
178 return TernaryOperate(opcode, operand_words[0], operand_words[1],
181 assert(false && "Invalid number of operands");
// Performs one folding step on |inst| in place.
// First, |inst| is folded to a constant with its operand ids used unchanged.
// On success, |inst| is rewritten as an OpCopyObject of the folded constant's
// result id. Otherwise, each FoldingRule registered for the opcode is tried
// with the constants known for the operands.
// NOTE(review): presumably returns true when |inst| was changed — confirm
// against the return statements.
186 bool FoldInstructionInternal(ir::Instruction* inst) {
187 ir::IRContext* context = inst->context();
// Operand ids are looked up as-is (no substitution).
188 auto identity_map = [](uint32_t id) { return id; };
189 ir::Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
190 if (folded_inst != nullptr) {
191 inst->SetOpcode(SpvOpCopyObject);
192 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
196 SpvOp opcode = inst->opcode();
197 analysis::ConstantManager* const_manager = context->get_constant_mgr();
199 std::vector<const analysis::Constant*> constants =
200 const_manager->GetOperandConstants(inst);
// Function-local static: the rule table is allocated on first use and is
// not deleted here.
202 static FoldingRules* rules = new FoldingRules();
203 for (FoldingRule rule : rules->GetRulesForOpcode(opcode)) {
204 if (rule(inst, constants)) {
// Returns the shared ConstantFoldingRules table. It is held through a
// function-local static pointer: allocated on first use and never deleted.
213 const ConstantFoldingRules& GetConstantFoldingRules() {
214 static ConstantFoldingRules* rules = new ConstantFoldingRules();
218 // Returns the result of performing an operation on scalar constant operands.
219 // This function extracts the operand values as 32 bit words and returns the
220 // result as a 32-bit word. Scalar constants wider than 32 bits are
221 // not accepted in this function.
// NOTE(review): every entry of |operands| is dereferenced, so callers must
// not pass nullptr entries (i.e. operands that are not known constants).
222 uint32_t FoldScalars(SpvOp opcode,
223 const std::vector<const analysis::Constant*>& operands) {
224 assert(IsFoldableOpcode(opcode) &&
225 "Unhandled instruction opcode in FoldScalars");
226 std::vector<uint32_t> operand_values_in_raw_words;
227 for (const auto& operand : operands) {
228 if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
229 const auto& scalar_words = scalar->words();
230 assert(scalar_words.size() == 1 &&
231 "Scalar constants with longer than 32-bit width are not allowed "
233 operand_values_in_raw_words.push_back(scalar_words.front());
// A null constant contributes the word 0.
234 } else if (operand->AsNullConstant()) {
235 operand_values_in_raw_words.push_back(0u);
238 "FoldScalars() only accepts ScalarConst or NullConst type of "
242 return OperateWords(opcode, operand_values_in_raw_words);
245 // Returns true if |inst| is a binary operation that takes two integers as
246 // parameters and folds to a constant that can be represented as an unsigned
247 // 32-bit value when the ids have been replaced by |id_map|. If |inst| can be
248 // folded, the resulting value is returned in |*result|. Valid result types for
249 // the instruction are any integer (signed or unsigned) with 32-bits or less, or
// Only one operand needs to be constant: the cases below recognise algebraic
// identities whose outcome does not depend on the other operand.
251 bool FoldBinaryIntegerOpToConstant(ir::Instruction* inst,
252 std::function<uint32_t(uint32_t)> id_map,
254 SpvOp opcode = inst->opcode();
255 ir::IRContext* context = inst->context();
256 analysis::ConstantManager* const_manger = context->get_constant_mgr();
// constants[i] is the IntConstant declared for in-operand i (after mapping
// its id through |id_map|), or nullptr when that operand is not a known
// integer constant.
259 const analysis::IntConstant* constants[2];
260 for (uint32_t i = 0; i < 2; i++) {
261 const ir::Operand* operand = &inst->GetInOperand(i);
262 if (operand->type != SPV_OPERAND_TYPE_ID) {
265 ids[i] = id_map(operand->words[0]);
266 const analysis::Constant* constant =
267 const_manger->FindDeclaredConstant(ids[i]);
268 constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
// x * 0 == 0 * x == 0.
273 case SpvOp::SpvOpIMul:
274 for (uint32_t i = 0; i < 2; i++) {
275 if (constants[i] != nullptr && constants[i]->IsZero()) {
// 0 / x and 0 % x are 0; a zero divisor is folded to 0 as well.
281 case SpvOp::SpvOpUDiv:
282 case SpvOp::SpvOpSDiv:
283 case SpvOp::SpvOpSRem:
284 case SpvOp::SpvOpSMod:
285 case SpvOp::SpvOpUMod:
286 // This changes undefined behaviour (ie divide by 0) into a 0.
287 for (uint32_t i = 0; i < 2; i++) {
288 if (constants[i] != nullptr && constants[i]->IsZero()) {
// Logical shifts by a constant amount of 32 or more are handled specially.
296 case SpvOp::SpvOpShiftRightLogical:
297 case SpvOp::SpvOpShiftLeftLogical:
298 if (constants[1] != nullptr) {
299 // When shifting by a value larger than the size of the result, the
300 // result is undefined. We are setting the undefined behaviour to a
302 uint32_t shift_amount = constants[1]->GetU32BitValue();
303 if (shift_amount >= 32) {
310 // Bitwise operations
// x | 0xFFFFFFFF == 0xFFFFFFFF.
311 case SpvOp::SpvOpBitwiseOr:
312 for (uint32_t i = 0; i < 2; i++) {
313 if (constants[i] != nullptr) {
314 // TODO: Change the mask against a value based on the bit width of the
315 // instruction result type. This way we can handle say 16-bit values
317 uint32_t mask = constants[i]->GetU32BitValue();
318 if (mask == 0xFFFFFFFF) {
319 *result = 0xFFFFFFFF;
// x & 0 == 0.
325 case SpvOp::SpvOpBitwiseAnd:
326 for (uint32_t i = 0; i < 2; i++) {
327 if (constants[i] != nullptr) {
328 if (constants[i]->IsZero()) {
// Comparisons with an operand at the extreme of its type's range have a
// fixed outcome. For example, UINT32_MAX < x and x < 0u are always false,
// and 0u <= x and x <= UINT32_MAX are always true. The signed variants use
// INT32_MIN and INT32_MAX the same way.
337 case SpvOp::SpvOpULessThan:
338 if (constants[0] != nullptr &&
339 constants[0]->GetU32BitValue() == UINT32_MAX) {
343 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
348 case SpvOp::SpvOpSLessThan:
349 if (constants[0] != nullptr &&
350 constants[0]->GetS32BitValue() == INT32_MAX) {
354 if (constants[1] != nullptr &&
355 constants[1]->GetS32BitValue() == INT32_MIN) {
360 case SpvOp::SpvOpUGreaterThan:
361 if (constants[0] != nullptr && constants[0]->IsZero()) {
365 if (constants[1] != nullptr &&
366 constants[1]->GetU32BitValue() == UINT32_MAX) {
371 case SpvOp::SpvOpSGreaterThan:
372 if (constants[0] != nullptr &&
373 constants[0]->GetS32BitValue() == INT32_MIN) {
377 if (constants[1] != nullptr &&
378 constants[1]->GetS32BitValue() == INT32_MAX) {
383 case SpvOp::SpvOpULessThanEqual:
384 if (constants[0] != nullptr && constants[0]->IsZero()) {
388 if (constants[1] != nullptr &&
389 constants[1]->GetU32BitValue() == UINT32_MAX) {
394 case SpvOp::SpvOpSLessThanEqual:
395 if (constants[0] != nullptr &&
396 constants[0]->GetS32BitValue() == INT32_MIN) {
400 if (constants[1] != nullptr &&
401 constants[1]->GetS32BitValue() == INT32_MAX) {
406 case SpvOp::SpvOpUGreaterThanEqual:
407 if (constants[0] != nullptr &&
408 constants[0]->GetU32BitValue() == UINT32_MAX) {
412 if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
417 case SpvOp::SpvOpSGreaterThanEqual:
418 if (constants[0] != nullptr &&
419 constants[0]->GetS32BitValue() == INT32_MAX) {
423 if (constants[1] != nullptr &&
424 constants[1]->GetS32BitValue() == INT32_MIN) {
435 // Returns true if |inst| is a binary operation on two boolean values, and folds
436 // to a constant boolean value when the ids have been replaced using |id_map|.
437 // If |inst| can be folded, the result value is returned in |*result|.
438 bool FoldBinaryBooleanOpToConstant(ir::Instruction* inst,
439 std::function<uint32_t(uint32_t)> id_map,
441 SpvOp opcode = inst->opcode();
442 ir::IRContext* context = inst->context();
443 analysis::ConstantManager* const_manger = context->get_constant_mgr();
// constants[i] is the BoolConstant declared for in-operand i (after mapping
// its id through |id_map|), or nullptr when that operand is not a known
// boolean constant.
446 const analysis::BoolConstant* constants[2];
447 for (uint32_t i = 0; i < 2; i++) {
448 const ir::Operand* operand = &inst->GetInOperand(i);
449 if (operand->type != SPV_OPERAND_TYPE_ID) {
452 ids[i] = id_map(operand->words[0]);
453 const analysis::Constant* constant =
454 const_manger->FindDeclaredConstant(ids[i]);
455 constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
// x || true is true regardless of x.
460 case SpvOp::SpvOpLogicalOr:
461 for (uint32_t i = 0; i < 2; i++) {
462 if (constants[i] != nullptr) {
463 if (constants[i]->value()) {
// x && false is false regardless of x.
470 case SpvOp::SpvOpLogicalAnd:
471 for (uint32_t i = 0; i < 2; i++) {
472 if (constants[i] != nullptr) {
473 if (!constants[i]->value()) {
487 // Returns true if |inst| can be folded to a constant when the ids have been
488 // substituted using |id_map|. If it can, the value is returned in |result|. If
489 // not, |result| is unchanged. It is assumed that not all operands are
490 // constant. Those cases are handled by |FoldScalars|.
491 bool FoldIntegerOpToConstant(ir::Instruction* inst,
492 std::function<uint32_t(uint32_t)> id_map,
494 assert(IsFoldableOpcode(inst->opcode()) &&
495 "Unhandled instruction opcode in FoldScalars");
// NOTE(review): the assertion message above names FoldScalars, although this
// is FoldIntegerOpToConstant (it appears to be copied from there).
// For two-operand instructions the integer identities are tried first, then
// the boolean ones.
496 switch (inst->NumInOperands()) {
498 return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
499 FoldBinaryBooleanOpToConstant(inst, id_map, result);
// Folds |opcode| component-wise over |num_dims| lanes of constant vector
// operands. For each lane d, the d-th component word of every operand is
// collected and evaluated with OperateWords. A null vector operand
// contributes 0 in every lane. The per-lane results are returned in lane
// order.
505 std::vector<uint32_t> FoldVectors(
506 SpvOp opcode, uint32_t num_dims,
507 const std::vector<const analysis::Constant*>& operands) {
508 assert(IsFoldableOpcode(opcode) &&
509 "Unhandled instruction opcode in FoldVectors");
510 std::vector<uint32_t> result;
511 for (uint32_t d = 0; d < num_dims; d++) {
512 std::vector<uint32_t> operand_values_for_one_dimension;
513 for (const auto& operand : operands) {
514 if (const analysis::VectorConstant* vector_operand =
515 operand->AsVectorConstant()) {
516 // Extract the raw value of the scalar component constants
517 // in 32-bit words here. The reason for not using FoldScalars() here
518 // is that we do not create temporary null constants as components
519 // when the vector operand is a NullConstant because Constant creation
520 // may need extra checks for the validity and that is not managed in
522 if (const analysis::ScalarConstant* scalar_component =
523 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
524 const auto& scalar_words = scalar_component->words();
526 scalar_words.size() == 1 &&
527 "Vector components with longer than 32-bit width are not allowed "
529 operand_values_for_one_dimension.push_back(scalar_words.front());
// NOTE(review): |operand| is already known to be a VectorConstant here.
// Assuming a constant cannot be both a VectorConstant and a NullConstant,
// this test never matches, so null *components* fall through to the
// assertion. It was presumably meant to test
// vector_operand->GetComponents().at(d)->AsNullConstant().
530 } else if (operand->AsNullConstant()) {
531 operand_values_for_one_dimension.push_back(0u);
534 "VectorConst should only has ScalarConst or NullConst as "
537 } else if (operand->AsNullConstant()) {
538 operand_values_for_one_dimension.push_back(0u);
541 "FoldVectors() only accepts VectorConst or NullConst type of "
545 result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
// Reports whether |opcode| is one of the operations that the word-level
// folders (FoldScalars, FoldVectors, which assert on it) can evaluate. The
// case labels below list the supported opcodes. Keep them in sync with the
// opcodes handled by UnaryOperate, BinaryOperate and TernaryOperate, which
// reject anything else as unsupported.
550 bool IsFoldableOpcode(SpvOp opcode) {
551 // NOTE: Extend to more opcodes as new cases are handled in the folder
554 case SpvOp::SpvOpBitwiseAnd:
555 case SpvOp::SpvOpBitwiseOr:
556 case SpvOp::SpvOpBitwiseXor:
557 case SpvOp::SpvOpIAdd:
558 case SpvOp::SpvOpIEqual:
559 case SpvOp::SpvOpIMul:
560 case SpvOp::SpvOpINotEqual:
561 case SpvOp::SpvOpISub:
562 case SpvOp::SpvOpLogicalAnd:
563 case SpvOp::SpvOpLogicalEqual:
564 case SpvOp::SpvOpLogicalNot:
565 case SpvOp::SpvOpLogicalNotEqual:
566 case SpvOp::SpvOpLogicalOr:
567 case SpvOp::SpvOpNot:
568 case SpvOp::SpvOpSDiv:
569 case SpvOp::SpvOpSelect:
570 case SpvOp::SpvOpSGreaterThan:
571 case SpvOp::SpvOpSGreaterThanEqual:
572 case SpvOp::SpvOpShiftLeftLogical:
573 case SpvOp::SpvOpShiftRightArithmetic:
574 case SpvOp::SpvOpShiftRightLogical:
575 case SpvOp::SpvOpSLessThan:
576 case SpvOp::SpvOpSLessThanEqual:
577 case SpvOp::SpvOpSMod:
578 case SpvOp::SpvOpSNegate:
579 case SpvOp::SpvOpSRem:
580 case SpvOp::SpvOpUDiv:
581 case SpvOp::SpvOpUGreaterThan:
582 case SpvOp::SpvOpUGreaterThanEqual:
583 case SpvOp::SpvOpULessThan:
584 case SpvOp::SpvOpULessThanEqual:
585 case SpvOp::SpvOpUMod:
// Returns true if |cst| can be consumed by the word-level folders: either a
// scalar constant occupying exactly one 32-bit word, or a null constant.
// Every other constant (e.g. a composite) is rejected.
592 bool IsFoldableConstant(const analysis::Constant* cst) {
593 // Currently supported constants are 32-bit values or null constants.
594 if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
595 return scalar->words().size() == 1;
597 return cst->AsNullConstant() != nullptr;
// Tries to fold |inst| to a constant. Each in-operand id is mapped through
// |id_map| before being looked up as a declared constant. On success, returns
// the instruction that defines the folded constant, as provided by the
// constant manager. NOTE(review): presumably returns nullptr when |inst|
// cannot be folded — confirm against the early-exit paths.
// Strategies, in order:
//   1. the registered constant folding rules for the opcode;
//   2. FoldScalars, when every operand is a known constant;
//   3. FoldIntegerOpToConstant, for partially-constant operations.
600 ir::Instruction* FoldInstructionToConstant(
601 ir::Instruction* inst, std::function<uint32_t(uint32_t)> id_map) {
602 ir::IRContext* context = inst->context();
603 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
// Nothing can be done unless one of the folders handles this opcode.
605 if (!inst->IsFoldableByFoldScalar() &&
606 !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
609 // Collect the values of the constant parameters.
// Non-constant operands are recorded as nullptr and set |missing_constants|.
610 std::vector<const analysis::Constant*> constants;
611 bool missing_constants = false;
612 inst->ForEachInId([&constants, &missing_constants, const_mgr,
613 &id_map](uint32_t* op_id) {
614 uint32_t id = id_map(*op_id);
615 const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
617 constants.push_back(nullptr);
618 missing_constants = true;
620 constants.push_back(const_op);
// The opcode-specific constant folding rules get the first chance.
624 if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
625 const analysis::Constant* folded_const = nullptr;
627 GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
628 folded_const = rule(inst, constants);
629 if (folded_const != nullptr) {
630 ir::Instruction* const_inst =
631 const_mgr->GetDefiningInstruction(folded_const);
632 // May be a new instruction that needs to be analysed.
633 context->UpdateDefUse(const_inst);
// Fall back to the 32-bit word folders.
639 uint32_t result_val = 0;
640 bool successful = false;
641 // If all parameters are constant, fold the instruction to a constant.
642 if (!missing_constants && inst->IsFoldableByFoldScalar()) {
643 result_val = FoldScalars(inst->opcode(), constants);
647 if (!successful && inst->IsFoldableByFoldScalar()) {
648 successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
// Materialize the single result word as a constant of |inst|'s result type.
652 const analysis::Constant* result_const =
653 const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
654 return const_mgr->GetDefiningInstruction(result_const);
// Reports whether values of the type declared by |type_inst| can be handled
// by the 32-bit word folders. Integer types qualify only when their width
// (in-operand 0 of OpTypeInt) is 32. OpTypeBool is also tested below;
// presumably it is accepted — confirm against that branch's body.
659 bool IsFoldableType(ir::Instruction* type_inst) {
660 // Support 32-bit integers.
661 if (type_inst->opcode() == SpvOpTypeInt) {
662 return type_inst->GetSingleWordInOperand(0) == 32;
665 if (type_inst->opcode() == SpvOpTypeBool) {
// Repeatedly applies FoldInstructionInternal to |inst| while it reports
// progress. It stops as soon as |inst| has become an OpCopyObject, which is
// the form produced when it is folded to a constant.
// NOTE(review): |modified| presumably records whether any step succeeded —
// confirm against the loop body and return statement.
672 bool FoldInstruction(ir::Instruction* inst) {
673 bool modified = false;
674 ir::Instruction* folded_inst(inst);
// folded_inst is only an alias for |inst|; &*folded_inst is the same pointer.
675 while (folded_inst->opcode() != SpvOpCopyObject &&
676 FoldInstructionInternal(&*folded_inst)) {
683 } // namespace spvtools