TORCH_CHECK(
fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
- TORCH_INTERNAL_ASSERT(!ReluFused);
- return packed_weight->apply_dynamic(std::move(input));
+ auto output = packed_weight->apply_dynamic(std::move(input));
+
+ // Call the relu operator here until fp16 linear dynamic in FBGEMM
+ // supports it natively.
+ if (ReluFused) {
+ output.relu_();
+ }
+ return output;
}
#else // USE_FBGEMM
static at::Tensor run(
m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8<false>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic"), TORCH_FN(QLinearDynamicInt8<true>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16<false>::run));
+ m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16<true>::run));
}
TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"));
+ m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
self.assertEqual(Y_fp32, Y_fp32_ref,
msg="torch.ops.quantized.fbgemm_linear_dynamic results are off")
+ @skipIfNoFBGEMM
+ def test_qlinear_dynamic_fp16(self):
+
+ options = itertools.product(
+ (2, 4), # batch_size
+ (4, 5, 12), # input_channels
+ (4, 7, 8), # output_channels
+ (True, False), # use_bias
+ (True, False), # use_relu
+ )
+ for batch_size, input_channels, output_channels, use_bias, use_relu in options:
+ qlinear_prepack = torch.ops.quantized.linear_prepack_fp16
+ if use_relu:
+ qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic_fp16
+ else:
+ qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16
+
+ x = torch.randn(batch_size, input_channels)
+ w = torch.randn(output_channels, input_channels)
+ bias = torch.randn(output_channels) if use_bias else None
+
+ w_packed = qlinear_prepack(w, bias)
+ out = qlinear_dynamic(x, w_packed)
+
+ # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors
+ # output is FP32
+ w_fp16 = w.to(torch.float16).to(torch.float32)
+ ref = F.linear(x, w_fp16, bias)
+ if use_relu:
+ ref.relu_()
+
+ self.assertEqual(out, ref)
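# Beyond the direct op calls exercised above, the usual way to reach the dynamic
# fp16 linear kernel is module-level dynamic quantization; a minimal sketch,
# assuming the standard quantize_dynamic workflow (the relu-fused op is only hit
# when a fused LinearReLU module is present, which is not shown here):
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
dq_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.float16)  # nn.Linear -> dynamic fp16 Linear
out = dq_model(torch.randn(2, 4))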
class TestDynamicQuantizedRNNOp(TestCase):
"""Tests the correctness of the dynamic quantized lstm/gru."""