#include <algorithm>
#include <cmath>
#include <cstdint>
-#include <limits>
#include <immintrin.h>
int32_t b_zero_point,
float c_scale,
int32_t c_zero_point) {
- // TODO: this intrinsic code is replicated in dnnlowp.cc,
- // fbgemm_i8i8_acc32.cc, conv_dnnlowp_op.cc, and here.
- // We need to somehow refactor this.
- __m256 min_v = _mm256_set1_ps(numeric_limits<uint8_t>::min());
- __m256 max_v = _mm256_set1_ps(numeric_limits<uint8_t>::max());
-
- __m256i shuffle_mask_v = _mm256_set_epi8(
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0x0c,
- 0x08,
- 0x04,
- 0x00,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0xff,
- 0x0c,
- 0x08,
- 0x04,
- 0x00);
__m256i permute_mask_v =
- _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
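+ // The packs_epi32/packus_epi16 sequence below interleaves results within
+ // 128-bit lanes; this permutation restores the sequential x, y, z, w order
+ // before the final store.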
- int len_aligned = len / VLEN * VLEN;
+ int len_aligned = len / (VLEN * 4) * (VLEN * 4);
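+ // Unrolled by 4: each iteration now produces 4 * VLEN uint8 results
+ // (32 bytes with the 8-wide vectors used here), filling one 256-bit store.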
int j = 0;
- for (; j < len_aligned; j += VLEN) {
+ for (; j < len_aligned; j += VLEN * 4) {
// Input is uint8_t but cvtepi8_epi32 assumes the input is int8_t,
// so we subtract 0x80, cvtepi8_epi32, and then add 0x80
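+ // For example, input byte 0xff (255): 0xff - 0x80 = 0x7f, which
+ // sign-extends to 127, and 127 + 0x80 recovers 255 as an int32.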
- __m256 in_v = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ // x: elements [j, j + VLEN)
+ __m256 in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
_mm256_cvtepi8_epi32(_mm_sub_epi8(
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(input0 + j)),
_mm_set1_epi8(0x80))),
_mm256_set1_epi32(0x80)));
- in_v = _mm256_fmadd_ps(
- in_v, _mm256_set1_ps(a_scale), _mm256_set1_ps(-a_zero_point * a_scale));
- __m256 acc_v = in_v;
+ in_v0 = _mm256_fmadd_ps(
+ in_v0,
+ _mm256_set1_ps(a_scale),
+ _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
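+ // Both zero-point corrections are folded into one FMA constant, so
+ // in_v0 = a_scale * a - (a_zero_point * a_scale + b_zero_point * b_scale).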
- in_v = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ __m256 in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
_mm256_cvtepi8_epi32(_mm_sub_epi8(
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(input1 + j)),
_mm_set1_epi8(0x80))),
_mm256_set1_epi32(0x80)));
- in_v = _mm256_fmadd_ps(
- in_v, _mm256_set1_ps(b_scale), _mm256_set1_ps(-b_zero_point * b_scale));
- acc_v = _mm256_add_ps(acc_v, in_v);
+ __m256 acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
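+ // A single FMA then completes the dequantized sum:
+ // acc_v = a_scale * (a - a_zero_point) + b_scale * (b - b_zero_point).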
- __m256 transformed_v = _mm256_fmadd_ps(
+ __m256 x_transformed_v = _mm256_fmadd_ps(
acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
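+ // Requantize into the output domain: acc / c_scale + c_zero_point.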
- __m256 clipped_v = _mm256_min_ps(
- _mm256_max_ps(
- transformed_v, ReluFused ? _mm256_set1_ps(c_zero_point) : min_v),
- max_v);
- __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
- rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
- rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v);
- *reinterpret_cast<int64_t*>(output + j) =
- _mm256_extract_epi64(rounded_v, 0);
+
+ // y: elements [j + VLEN, j + 2 * VLEN)
+ in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ _mm256_cvtepi8_epi32(_mm_sub_epi8(
+ _mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(input0 + j + VLEN)),
+ _mm_set1_epi8(0x80))),
+ _mm256_set1_epi32(0x80)));
+ in_v0 = _mm256_fmadd_ps(
+ in_v0,
+ _mm256_set1_ps(a_scale),
+ _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
+
+ in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ _mm256_cvtepi8_epi32(_mm_sub_epi8(
+ _mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(input1 + j + VLEN)),
+ _mm_set1_epi8(0x80))),
+ _mm256_set1_epi32(0x80)));
+ acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
+
+ __m256 y_transformed_v = _mm256_fmadd_ps(
+ acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
+
+ // z: elements [j + 2 * VLEN, j + 3 * VLEN)
+ in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ _mm256_cvtepi8_epi32(_mm_sub_epi8(
+ _mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(input0 + j + 2 * VLEN)),
+ _mm_set1_epi8(0x80))),
+ _mm256_set1_epi32(0x80)));
+ in_v0 = _mm256_fmadd_ps(
+ in_v0,
+ _mm256_set1_ps(a_scale),
+ _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
+
+ in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ _mm256_cvtepi8_epi32(_mm_sub_epi8(
+ _mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(input1 + j + 2 * VLEN)),
+ _mm_set1_epi8(0x80))),
+ _mm256_set1_epi32(0x80)));
+ acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
+
+ __m256 z_transformed_v = _mm256_fmadd_ps(
+ acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
+
+ // w: elements [j + 3 * VLEN, j + 4 * VLEN)
+ in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ _mm256_cvtepi8_epi32(_mm_sub_epi8(
+ _mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(input0 + j + 3 * VLEN)),
+ _mm_set1_epi8(0x80))),
+ _mm256_set1_epi32(0x80)));
+ in_v0 = _mm256_fmadd_ps(
+ in_v0,
+ _mm256_set1_ps(a_scale),
+ _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
+
+ in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+ _mm256_cvtepi8_epi32(_mm_sub_epi8(
+ _mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(input1 + j + 3 * VLEN)),
+ _mm_set1_epi8(0x80))),
+ _mm256_set1_epi32(0x80)));
+ acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
+
+ __m256 w_transformed_v = _mm256_fmadd_ps(
+ acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
+
+ // See the requantizeOutputProcessingAvx2 function in
+ // fbgemm/src/QuantUtilsAvx2.cc for more details on this instruction
+ // sequence.
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
+ __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
+ __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
+ __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);
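+ // _mm256_cvtps_epi32 rounds to nearest (even) under the default MXCSR
+ // rounding mode.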
+
+ __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
+ __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
+ __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
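+ // packs_epi32 narrows int32 -> int16 with signed saturation and
+ // packus_epi16 narrows int16 -> uint8 with unsigned saturation; both
+ // operate independently within each 128-bit lane.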
+ __m256i xyzw_clamped_v = _mm256_max_epu8(
+ ReluFused ? _mm256_set1_epi8(c_zero_point) : _mm256_setzero_si256(),
+ _mm256_min_epu8(
+ xyzw_packed_v, _mm256_set1_epi8(static_cast<uint8_t>(255))));
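+ // packus_epi16 already saturates to [0, 255], so this clamp only has an
+ // effect when ReluFused raises the lower bound to c_zero_point.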
+
+ xyzw_clamped_v =
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(output + j), xyzw_clamped_v);
}
for (; j < len; ++j) {
float acc = 0;