class BF16Transformer {
    // layer types that may initiate a BF16 data-flow region
    const InferenceEngine::details::caseless_set<std::string> _initbf16 =
-        { "convolution", "fullyconnected", "innerproduct" };
+        { "convolution", "fullyconnected", "innerproduct", "gemm" };
    // layer types that may continue a BF16 region started by an _initbf16 layer
    const InferenceEngine::details::caseless_set<std::string> _complementbf16 =
        { "relu", "tanh", "elu", "square", "abs", "sqrt", "linear", "bounded_relu", "soft_relu", "logistic",
          "exp", "gelu", "clamp", "swish", "prelu", "pooling", "norm", "gather", "memory" };
auto inPrec0 = getCnnLayer()->insData[0].lock()->getPrecision();
auto inPrec1 = getCnnLayer()->insData[1].lock()->getPrecision();
// only the two-input {U8|I8, I8} pair can stay on the int8 kernel; any other
// combination, including the three-input variant, falls back below
if ((inPrec0 != Precision::U8 && inPrec0 != Precision::I8) || inPrec1 != Precision::I8 || isThreeInputs) {
-    inPrec0 = Precision::FP32;
-    inPrec1 = Precision::FP32;
+    // keep the pair in BF16 when either input already carries BF16;
+    // otherwise fall back to FP32 exactly as before
+    if (inPrec0 == Precision::BF16 || inPrec1 == Precision::BF16) {
+        inPrec0 = Precision::BF16;
+        inPrec1 = Precision::BF16;
+    } else {
+        inPrec0 = Precision::FP32;
+        inPrec1 = Precision::FP32;
+    }
}
auto inputDataType0 = MKLDNNExtensionUtils::IEPrecisionToDataType(inPrec0);
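
// The rule above, restated as a self-contained helper for clarity. The function
// is invented for this sketch (needs <utility> plus the usual InferenceEngine
// Precision enum), not part of the patch:
//   * a two-input {U8|I8, I8} pair keeps its int8 precisions;
//   * any other combination collapses to BF16 if either side is BF16,
//     and to FP32 in every remaining case.
inline std::pair<Precision, Precision> resolveGemmPrecisions(Precision p0, Precision p1, bool isThreeInputs) {
    const bool int8Pair = (p0 == Precision::U8 || p0 == Precision::I8)
                          && p1 == Precision::I8 && !isThreeInputs;
    if (int8Pair)
        return {p0, p1};
    if (p0 == Precision::BF16 || p1 == Precision::BF16)
        return {Precision::BF16, Precision::BF16};  // mixed pairs are upgraded, not demoted
    return {Precision::FP32, Precision::FP32};
}
// e.g. resolveGemmPrecisions(Precision::FP32, Precision::BF16, false) yields {BF16, BF16}
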
inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const float *A, int lda,
                         const float *B, int ldb, float beta, float *C, int ldc) {
    mkldnn_sgemm(transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
+// BF16 inputs arrive as raw uint16_t bit patterns; mkldnn_gemm_bf16bf16f32
+// multiplies in bfloat16 and accumulates into the fp32 output matrix
+inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const uint16_t *A, int lda,
+                         const uint16_t *B, int ldb, float beta, float *C, int ldc) {
+    mkldnn_gemm_bf16bf16f32(transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
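// Why uint16_t carries bfloat16: bf16 is the upper half of an IEEE-754
// binary32, so a plain 16-bit integer holds the bit pattern. A minimal
// round-to-nearest-even encoder, shown only to make the layout concrete
// (invented helper; needs <cstdint> and <cstring>; NaN handling omitted):
static inline uint16_t float_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);       // bit-cast without UB
    bits += 0x7FFFu + ((bits >> 16) & 1u);     // round to nearest, ties to even
    return static_cast<uint16_t>(bits >> 16);  // keep sign, exponent, top mantissa bits
}
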
// existing int8 overload, shown for contrast: unsigned 8-bit A, signed 8-bit B
inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const uint8_t *A, int lda,
                         const int8_t *B, int ldb, float beta, float *C, int ldc) {
    const int32_t co = 0;  // no offset applied to the int32 GEMM result
    // ...
// the resolved input precision then selects the matching overload above:
    case Precision::FP32:
        process_data<float, float>();
        break;
+   case Precision::BF16:
+       process_data<uint16_t, uint16_t>();  // bf16 payload travels as uint16_t
+       break;
    case Precision::I8:
        process_data<int8_t, int8_t>();
        break;
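
// Sketch of the dispatch the switch above drives. The excerpt's process_data
// is plausibly a templated helper along these lines: the template arguments
// fix the element types, and ordinary overload resolution then picks the
// matching process_gemm. Function and parameter names here are invented.
template <typename T0, typename T1>
void run_gemm(const void *a, const void *b, float *c,
              char transa, char transb, int M, int N, int K,
              float alpha, int lda, int ldb, float beta, int ldc) {
    process_gemm(transa, transb, M, N, K, alpha,
                 reinterpret_cast<const T0 *>(a), lda,
                 reinterpret_cast<const T1 *>(b), ldb,
                 beta, c, ldc);
}
// e.g. run_gemm<uint16_t, uint16_t>(a, b, c, 'N', 'N', M, N, K, 1.f, lda, ldb, 0.f, ldc)
// resolves to the mkldnn_gemm_bf16bf16f32 overload added above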