[Static Runtime] Implement out variant for fb::quantized_linear (#63635)
authorHao Lu <hlu@fb.com>
Sat, 21 Aug 2021 04:41:19 +0000 (21:41 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 21 Aug 2021 04:42:22 +0000 (21:42 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63635

Reviewed By: ajyu

Differential Revision: D30446234

fbshipit-source-id: 1ef014186ff725930a97d0159626f9233ee74030

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp

index 9946c7a..7338012 100644 (file)
@@ -719,3 +719,14 @@ const auto append_tensor_script = R"JIT(
       lst.append(a)
       return lst
 )JIT";
+
+const std::string quantize_script = R"IR(
+  graph(%input: Tensor, %weights: Tensor):
+      %scale: float = prim::Constant[value=1.]()
+      %zero_point: int = prim::Constant[value=1]()
+      %bias: None = prim::Constant()
+      %packed_params = quantized::linear_prepack(%weights, %bias)
+      %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point)
+      %1249: Tensor = aten::dequantize(%1254)
+      return (%1249)
+)IR";
index ec703ef..dfe2c14 100644 (file)
@@ -1172,3 +1172,17 @@ TEST(StaticRuntime, IndividualOps_Append) {
   testStaticRuntime(append_tensor_script, args_tensor);
   testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large);
 }
+
+TEST(StaticRuntime, QuantizedLinear) {
+  at::Tensor weight =
+      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
+  at::Tensor input =
+      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);
+
+  at::Tensor weight_2 =
+      at::quantize_per_tensor(torch::randn({4, 3}), 2, 3, torch::kQInt8);
+  at::Tensor input_2 =
+      at::quantize_per_tensor(torch::randn({4, 3}), 2, 3, torch::kQUInt8);
+
+  testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2});
+}
index eef5595..2543182 100644 (file)
@@ -1528,6 +1528,53 @@ REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SR
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(
+    fb::quantized_linear,
+    fb_quantized_linear,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "fb::quantized_linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase w_prepack, Tensor Y_scale_i, Tensor Y_zero_point_i) -> Tensor"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      const auto w = toIValue(n->inputs()[1]);
+      c10::intrusive_ptr<LinearPackedParamsBase> packed_weight;
+      if (w) {
+        packed_weight = w->toCustomClass<LinearPackedParamsBase>();
+      }
+      return [packed_weight](ProcessedNode* p_node) {
+        const auto& input = p_node->Input(0).toTensor();
+        const auto output_scale = p_node->Input(2).toTensor().item().toFloat();
+        const auto output_zero_point =
+            p_node->Input(3).toTensor().item().toLong();
+
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = at::native::empty_affine_quantized(
+              {0},
+              c10::kQUInt8,
+              c10::nullopt,
+              c10::kCPU,
+              false,
+              output_scale,
+              output_zero_point,
+              c10::nullopt);
+        }
+        auto& out_t = p_node->Output(0).toTensor();
+        fastResizeToZero(out_t);
+
+        if (packed_weight) {
+          packed_weight->apply_out(
+              input, output_scale, output_zero_point, out_t);
+        } else {
+          // Weights could be quantized on the fly
+          auto packed_weight_tmp =
+              p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
+          packed_weight_tmp->apply_out(
+              input, output_scale, output_zero_point, out_t);
+        }
+      };
+    });
+
 REGISTER_OPERATOR_FUNCTOR(aten::full, aten_full, [](Node* n) -> SROperator {
   if (!n->matches(torch::schema(
           "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {