import math
a = 'a'
+b = 'b'
+c = 'c'
# The nir_lower_bit_size() pass gets rid of all 8-bit ALUs but inserts new u2u8
# and i2i8 operations to convert the result back to the original type after the
# operation; those casts are lowered by the dxil_nir_lower_8bit_conv pass
# generated below.
remove_unsupported_casts(no_8bit_conv, 8, 0xff, 255.0, -128.0, 127.0)
remove_unsupported_casts(no_16bit_conv, 16, 0xffff, 65535.0, -32768.0, 32767.0)
-lower_x2b = [
+algebraic_ops = [
(('b2b32', 'a'), ('b2i32', 'a')),
(('b2b1', 'a'), ('ine', ('b2i32', a), 0)),
+
+ # We don't support the saturating versions of these
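+ # The 4x8 dot product itself can't overflow 32 bits, so it is enough to
+ # saturate only the final accumulate.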
+ (('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', ('sdot_4x8_iadd', a, b, 0), c)),
+ (('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', ('udot_4x8_uadd', a, b, 0), c)),
]
no_16bit_conv += [
no_8bit_conv).render())
print(nir_algebraic.AlgebraicPass("dxil_nir_lower_16bit_conv",
no_16bit_conv).render())
- print(nir_algebraic.AlgebraicPass("dxil_nir_lower_x2b",
- lower_x2b).render())
+ print(nir_algebraic.AlgebraicPass("dxil_nir_algebraic",
+ algebraic_ops).render())
if __name__ == '__main__':
main()
options->lower_doubles_options = ~0;
if ((supported_int_sizes & 16) && (supported_float_sizes & 16))
options->support_16bit_alu = true;
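+ /* The packed 8-bit dot-product intrinsics were introduced in Shader Model 6.4. */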
+ if (shader_model_max >= SHADER_MODEL_6_4) {
+ options->has_sdot_4x8 = true;
+ options->has_udot_4x8 = true;
+ }
}
static bool
DXIL_INTR_RAW_BUFFER_LOAD = 139,
DXIL_INTR_RAW_BUFFER_STORE = 140,
+ DXIL_INTR_DOT4_ADD_I8_PACKED = 163,
+ DXIL_INTR_DOT4_ADD_U8_PACKED = 164,
+
DXIL_INTR_ANNOTATE_HANDLE = 216,
DXIL_INTR_CREATE_HANDLE_FROM_BINDING = 217,
DXIL_INTR_CREATE_HANDLE_FROM_HEAP = 218,
return true;
}
+static bool
+emit_dot4add_packed(struct ntd_context *ctx, nir_alu_instr *alu,
+ enum dxil_intr intr,
+ const struct dxil_value *src0,
+ const struct dxil_value *src1,
+ const struct dxil_value *accum)
+{
+ const struct dxil_func *f = dxil_get_function(&ctx->mod, "dx.op.dot4AddPacked", DXIL_I32);
+ if (!f)
+ return false;
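+ /* dx.op.dot4AddPacked operands: DXIL opcode, 32-bit accumulator, then the two packed 8-bit source dwords. */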
+ const struct dxil_value *srcs[] = { dxil_module_get_int32_const(&ctx->mod, intr), accum, src0, src1 };
+ const struct dxil_value *v = dxil_emit_call(&ctx->mod, f, srcs, ARRAY_SIZE(srcs));
+ if (!v)
+ return false;
+
+ store_alu_dest(ctx, alu, 0, v);
+ return true;
+}
+
static bool emit_select(struct ntd_context *ctx, nir_alu_instr *alu,
const struct dxil_value *sel,
const struct dxil_value *val_true,
case nir_op_unpack_half_2x16_split_y: return emit_f16tof32(ctx, alu, src[0], true);
case nir_op_pack_half_2x16_split: return emit_f32tof16(ctx, alu, src[0], src[1]);
+ case nir_op_sdot_4x8_iadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_I8_PACKED, src[0], src[1], src[2]);
+ case nir_op_udot_4x8_uadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_U8_PACKED, src[0], src[1], src[2]);
+
case nir_op_i2i1:
case nir_op_u2u1:
case nir_op_b2i16:
NIR_PASS(progress, s, nir_opt_cse);
NIR_PASS(progress, s, nir_opt_peephole_select, 8, true, true);
NIR_PASS(progress, s, nir_opt_algebraic);
- NIR_PASS(progress, s, dxil_nir_lower_x2b);
+ NIR_PASS(progress, s, dxil_nir_algebraic);
if (s->options->lower_int64_options)
NIR_PASS(progress, s, nir_lower_int64);
NIR_PASS(progress, s, nir_lower_alu);