glsl: Fix codegen for constant ir_binop_{l,r}shift with mixed types
authorSviatoslav Peleshko <sviatoslav.peleshko@globallogic.com>
Thu, 28 Jul 2022 09:46:37 +0000 (12:46 +0300)
committerMarge Bot <emma+marge@anholt.net>
Fri, 17 Mar 2023 05:00:22 +0000 (05:00 +0000)
Fixes: 13106e10 ("glsl: Generate code for constant ir_binop_lshift and ir_binop_rshift expressions")

Signed-off-by: Sviatoslav Peleshko <sviatoslav.peleshko@globallogic.com>
Reviewed-by: Matt Turner <mattst88@gmail.com>
Reviewed-by: Kenneth Graunke <kenneth@whitecape.org>
Reviewed-by: Timothy Arceri <tarceri@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17787>

src/compiler/glsl/ir_expression_operation.py

index c9f9831..fa3118b 100644 (file)
@@ -22,6 +22,7 @@
 
 import mako.template
 import sys
+import itertools
 
 class type(object):
    def __init__(self, c_type, union_field, glsl_type):
@@ -112,20 +113,10 @@ constant_template_common = mako.template.Template("""\
       break;""")
 
 # This template is for binary operations that can operate on some combination
-# of scalar and vector operands.
+# of scalar and vector operands where both source types are the same.
 constant_template_vector_scalar = mako.template.Template("""\
    case ${op.get_enum_name()}:
-    % if "mixed" in op.flags:
-        % for i in range(op.num_operands):
-      assert(op[${i}]->type->base_type == ${op.source_types[0].glsl_type} ||
-            % for src_type in op.source_types[1:-1]:
-             op[${i}]->type->base_type == ${src_type.glsl_type} ||
-            % endfor
-             op[${i}]->type->base_type == ${op.source_types[-1].glsl_type});
-        % endfor
-    % else:
       assert(op[0]->type == op[1]->type || op0_scalar || op1_scalar);
-    % endif
       for (unsigned c = 0, c0 = 0, c1 = 0;
            c < components;
            c0 += c0_inc, c1 += c1_inc, c++) {
@@ -142,6 +133,41 @@ constant_template_vector_scalar = mako.template.Template("""\
       }
       break;""")
 
+# This template is for binary operations that can operate on some combination
+# of scalar and vector operands where the source types can be mixed.
+constant_template_vector_scalar_mixed = mako.template.Template("""\
+   case ${op.get_enum_name()}:
+        % for i in range(op.num_operands):
+      assert(op[${i}]->type->base_type == ${op.source_types[0].glsl_type} ||
+            % for src_type in op.source_types[1:-1]:
+             op[${i}]->type->base_type == ${src_type.glsl_type} ||
+            % endfor
+             op[${i}]->type->base_type == ${op.source_types[-1].glsl_type});
+        % endfor
+      for (unsigned c = 0, c0 = 0, c1 = 0;
+           c < components;
+           c0 += c0_inc, c1 += c1_inc, c++) {
+         <%
+           first_sig_dst_type, first_sig_src_types = op.signatures()[0]
+           last_sig_dst_type, last_sig_src_types = op.signatures()[-1]
+         %>
+         if (op[0]->type->base_type == ${first_sig_src_types[0].glsl_type} &&
+             op[1]->type->base_type == ${first_sig_src_types[1].glsl_type}) {
+            data.${first_sig_dst_type.union_field}[c] = ${op.get_c_expression(first_sig_src_types, ("c0", "c1", "c2"))};
+    % for dst_type, src_types in op.signatures()[1:-1]:
+         } else if (op[0]->type->base_type == ${src_types[0].glsl_type} &&
+                    op[1]->type->base_type == ${src_types[1].glsl_type}) {
+            data.${dst_type.union_field}[c] = ${op.get_c_expression(src_types, ("c0", "c1", "c2"))};
+    % endfor
+         } else if (op[0]->type->base_type == ${last_sig_src_types[0].glsl_type} &&
+                    op[1]->type->base_type == ${last_sig_src_types[1].glsl_type}) {
+            data.${last_sig_dst_type.union_field}[c] = ${op.get_c_expression(last_sig_src_types, ("c0", "c1", "c2"))};
+         } else {
+            unreachable("invalid types");
+         }
+      }
+      break;""")
+
 # This template is for multiplication.  It is unique because it has to support
 # matrix * vector and matrix * matrix operations, and those are just different.
 constant_template_mul = mako.template.Template("""\
@@ -378,7 +404,10 @@ class operation(object):
          elif self.name == "vector_extract":
             return constant_template_vector_extract.render(op=self)
          elif vector_scalar_operation in self.flags:
-            return constant_template_vector_scalar.render(op=self)
+            if mixed_type_operation in self.flags:
+               return constant_template_vector_scalar_mixed.render(op=self)
+            else:
+               return constant_template_vector_scalar.render(op=self)
       elif self.num_operands == 3:
          if self.name == "vector_insert":
             return constant_template_vector_insert.render(op=self)
@@ -663,8 +692,12 @@ ir_expression_operation = [
    operation("any_nequal", 2, source_types=all_types, dest_type=bool_type, c_expression="!op[0]->has_value(op[1])", flags=frozenset((horizontal_operation, types_identical_operation))),
 
    # Bit-wise binary operations.
-   operation("lshift", 2, printable_name="<<", source_types=integer_types, c_expression="{src0} << {src1}", flags=frozenset((vector_scalar_operation, mixed_type_operation))),
-   operation("rshift", 2, printable_name=">>", source_types=integer_types, c_expression="{src0} >> {src1}", flags=frozenset((vector_scalar_operation, mixed_type_operation))),
+   operation("lshift", 2,
+             printable_name="<<", all_signatures=list((src_sig[0], src_sig) for src_sig in itertools.product(integer_types, repeat=2)),
+             source_types=integer_types, c_expression="{src0} << {src1}", flags=frozenset((vector_scalar_operation, mixed_type_operation))),
+   operation("rshift", 2,
+             printable_name=">>", all_signatures=list((src_sig[0], src_sig) for src_sig in itertools.product(integer_types, repeat=2)),
+             source_types=integer_types, c_expression="{src0} >> {src1}", flags=frozenset((vector_scalar_operation, mixed_type_operation))),
    operation("bit_and", 2, printable_name="&", source_types=integer_types, c_expression="{src0} & {src1}", flags=vector_scalar_operation),
    operation("bit_xor", 2, printable_name="^", source_types=integer_types, c_expression="{src0} ^ {src1}", flags=vector_scalar_operation),
    operation("bit_or", 2, printable_name="|", source_types=integer_types, c_expression="{src0} | {src1}", flags=vector_scalar_operation),