lst.append(a)
return lst
)JIT";
+
// IR text for a graph over (%input, %weights). It defines constants
// scale = 1., zero_point = 1 and a None bias, prepacks %weights with that bias
// via quantized::linear_prepack, feeds the result to quantized::linear, and
// returns the aten::dequantize of that output.
+const std::string quantize_script = R"IR(
+ graph(%input: Tensor, %weights: Tensor):
+ %scale: float = prim::Constant[value=1.]()
+ %zero_point: int = prim::Constant[value=1]()
+ %bias: None = prim::Constant()
+ %packed_params = quantized::linear_prepack(%weights, %bias)
+ %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point)
+ %1249: Tensor = aten::dequantize(%1254)
+ return (%1249)
+)IR";
testStaticRuntime(append_tensor_script, args_tensor);
testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large);
}
+
+// Runs quantize_script through testStaticRuntime with two argument sets of
+// quantized {input, weight} tensors: 3x2 tensors first, then 4x3 tensors.
+TEST(StaticRuntime, QuantizedLinear) {
+  // Draws a random tensor of the given shape and quantizes it per-tensor
+  // with scale 2 and zero point 3 into the requested quantized dtype.
+  const auto random_quantized = [](at::IntArrayRef shape,
+                                   at::ScalarType dtype) {
+    return at::quantize_per_tensor(torch::randn(shape), 2, 3, dtype);
+  };
+
+  // Within each pair the weight is drawn before the input.
+  const at::Tensor first_weight = random_quantized({3, 2}, torch::kQInt8);
+  const at::Tensor first_input = random_quantized({3, 2}, torch::kQUInt8);
+
+  const at::Tensor second_weight = random_quantized({4, 3}, torch::kQInt8);
+  const at::Tensor second_input = random_quantized({4, 3}, torch::kQUInt8);
+
+  testStaticRuntime(
+      quantize_script,
+      {first_input, first_weight},
+      {second_input, second_weight});
+}
};
});
+// Static runtime functor for fb::quantized_linear.
+//
+// Returns nullptr (after logging the schema) when the node does not match the
+// expected schema. Otherwise returns an operator that reads the output scale
+// and zero point from tensor inputs 2 and 3, makes sure output 0 holds a
+// quint8 affine-quantized CPU tensor, and writes the linear result into it
+// through LinearPackedParamsBase::apply_out.
+REGISTER_OPERATOR_FUNCTOR(
+    fb::quantized_linear,
+    fb_quantized_linear,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "fb::quantized_linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase w_prepack, Tensor Y_scale_i, Tensor Y_zero_point_i) -> Tensor"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+
+      // When toIValue yields a value for input 1 at registration time, hold
+      // on to the packed params so the operator does not re-read them per
+      // run. (Presumably this is the constant-weight case -- confirm.)
+      c10::intrusive_ptr<LinearPackedParamsBase> prepacked;
+      if (const auto weight_ivalue = toIValue(n->inputs()[1])) {
+        prepacked = weight_ivalue->toCustomClass<LinearPackedParamsBase>();
+      }
+
+      return [prepacked](ProcessedNode* p_node) {
+        const auto& x = p_node->Input(0).toTensor();
+        const auto y_scale = p_node->Input(2).toTensor().item().toFloat();
+        const auto y_zero_point = p_node->Input(3).toTensor().item().toLong();
+
+        // Create the output tensor only if the slot is still empty; its
+        // quantization parameters come from the scale/zero point read above.
+        auto& out_slot = p_node->Output(0);
+        if (out_slot.isNone()) {
+          out_slot = at::native::empty_affine_quantized(
+              {0},
+              c10::kQUInt8,
+              c10::nullopt,
+              c10::kCPU,
+              false,
+              y_scale,
+              y_zero_point,
+              c10::nullopt);
+        }
+        auto& out_t = out_slot.toTensor();
+        fastResizeToZero(out_t);
+
+        if (prepacked) {
+          prepacked->apply_out(x, y_scale, y_zero_point, out_t);
+          return;
+        }
+
+        // No params were captured up front: the weights could be quantized on
+        // the fly, so fetch the packed params from input 1 on every run.
+        auto runtime_params =
+            p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
+        runtime_params->apply_out(x, y_scale, y_zero_point, out_t);
+      };
+    });
+
+
REGISTER_OPERATOR_FUNCTOR(aten::full, aten_full, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
"aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {