End to end hack to call server side Caffe2 ops (#18267)
authorDmytro Dzhulgakov <dzhulgakov@fb.com>
Fri, 22 Mar 2019 18:11:16 +0000 (11:11 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 18:17:45 +0000 (11:17 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18267

Motivation: we don't actually want to use it for real under any circumstances. This is an idea to unblock our internal progress and parallelize workstreams. We can easily define schemas for all ops in question and implement forwarding to C2 ops which is NOT going to be performant. Then several things can be happening in parallel:
* move code of ops outside of C2 ops that depend on protobuf into c10
* development of optimization/fusion passes
* building python-level wrappers with clean API
* improving perf

This demonstrates, Relu, quant, dequant. It seems to cover all use cases necessary (maybe except weights prepacking). Ideally I'd demonstrate Conv, but will get to it later in a separate PR (contributions welcomed)

Reviewed By: ezyang

Differential Revision: D14531232

fbshipit-source-id: 4cd4a71ae0cb373c6c0e81f965c442b82a1b4069

test/test_quantized.py [new file with mode: 0644]
tools/build_variables.py
tools/run-clang-tidy-in-ci.sh
torch/CMakeLists.txt
torch/csrc/jit/register_quantized_ops.cpp [new file with mode: 0644]

diff --git a/test/test_quantized.py b/test/test_quantized.py
new file mode 100644 (file)
index 0000000..648be98
--- /dev/null
@@ -0,0 +1,51 @@
+import torch
+import torch.jit
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import unittest
+from caffe2.python import core
+from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
+    skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
+    freeze_rng_state, set_rng_seed
+
+
+def canonical(graph):
+    return str(torch._C._jit_pass_canonicalize(graph))
+
+
+@unittest.skipIf("Relu_ENGINE_DNNLOWP" not in core._REGISTERED_OPERATORS, "fbgemm-based Caffe2 ops are not linked")
+class TestQuantized(TestCase):
+    def test_relu(self):
+        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
+        r = torch.ops.c10.quantized_relu(a)
+        np.testing.assert_equal(r[0].numpy(), torch.tensor([5, 6, 5, 10], dtype=torch.uint8).numpy())
+        np.testing.assert_almost_equal(0.01, r[1])
+        self.assertEqual(5, r[2])
+
+    def test_quantize(self):
+        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
+        r = torch.ops.c10.dequantize(a)
+        np.testing.assert_almost_equal(r.numpy(), [-0.01, 0.01, -0.04, 0.05])
+        # default args
+        q_def = torch.ops.c10.quantize(r)
+        # specified
+        q = torch.ops.c10.quantize(r, scale=0.01, zero_point=5)
+        np.testing.assert_equal(q[0].numpy(), a[0].numpy())
+        np.testing.assert_almost_equal(q[1], a[1])
+        self.assertEqual(q[2], a[2])
+
+    def test_script(self):
+        @torch.jit.script
+        def foo(x):
+            # type: (Tuple[Tensor, float, int]) -> Tuple[Tensor, float, int]
+            return torch.ops.c10.quantized_relu(x)
+        self.assertExpectedInline(canonical(foo.graph), '''\
+graph(%x : (Tensor, float, int)):
+  %1 : (Tensor, float, int) = c10::quantized_relu(%x)
+  return (%1)
+''')
+
+
+if __name__ == '__main__':
+    run_tests()
index 39dce39..686287b 100644 (file)
@@ -89,6 +89,7 @@ libtorch_sources = [
     "torch/csrc/jit/passes/utils/memory_dag.cpp",
     "torch/csrc/jit/register_prim_ops.cpp",
     "torch/csrc/jit/register_special_ops.cpp",
+    "torch/csrc/jit/register_quantized_ops.cpp",
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
@@ -199,6 +200,7 @@ def add_torch_libs():
             "//caffe2/aten:ATen-cpu",
             "//caffe2/caffe2:caffe2_cpu",
             "//caffe2/torch/lib/libshm:libshm",
+            "//caffe2/caffe2/quantization/server:dnnlowp_ops",
         ],
         external_deps=[
             ("nanopb", None, "protobuf-nanopb"),
index 6426529..57ce282 100755 (executable)
@@ -38,12 +38,13 @@ fi
 # Run Clang-Tidy
 # The negative filters below are to exclude files that include onnx_pb.h or
 # caffe2_pb.h, otherwise we'd have to build protos as part of this CI job.
-time python tools/clang_tidy.py            \
-  --verbose                                \
-  --paths torch/csrc/                      \
-  --diff "$BASE_BRANCH"                    \
-  -g"-torch/csrc/distributed/Module.cpp"   \
-  -g"-torch/csrc/jit/export.cpp"           \
-  -g"-torch/csrc/jit/import.cpp"           \
-  -g"-torch/csrc/jit/netdef_converter.cpp" \
+time python tools/clang_tidy.py                  \
+  --verbose                                      \
+  --paths torch/csrc/                            \
+  --diff "$BASE_BRANCH"                          \
+  -g"-torch/csrc/distributed/Module.cpp"         \
+  -g"-torch/csrc/jit/export.cpp"                 \
+  -g"-torch/csrc/jit/import.cpp"                 \
+  -g"-torch/csrc/jit/netdef_converter.cpp"       \
+  -g"-torch/csrc/jit/register_quantized_ops.cpp" \
   "$@"
index 434ae44..bee3ef5 100644 (file)
@@ -169,6 +169,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/register_quantized_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/scope.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
   ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
diff --git a/torch/csrc/jit/register_quantized_ops.cpp b/torch/csrc/jit/register_quantized_ops.cpp
new file mode 100644 (file)
index 0000000..9173cb2
--- /dev/null
@@ -0,0 +1,146 @@
+// WARNING! WARNING! WARNING!
+// This file is a temporary hack to enable development of pytorch quantization
+//
+// It effectively wraps Caffe2 ops as is through custom jit ops mechanism
+// It obviously has terrible performance - caffe2 operator instance is created
+// on each invocation and also creation involves creating a protobuf (sigh...)
+//
+// Our plan is to implement quantized operators natively in c10 as operators and
+// also enforce some additional contracts on operator semantics:
+// - explicitly express weights prepacking as a separate operator to signify
+//   reliance on weights being constant
+// - don't modify arguments of the op (OperatorDef) to store data
+// - explicitly model figuring out quantization params for dynamic quantization
+//   instead of memorizing the first batch's params
+
+#include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/operator.h>
+
+#include <caffe2/core/operator.h>
+#include <caffe2/core/tensor_int8.h>
+#include <torch/csrc/autograd/variable.h>
+
+namespace torch {
+namespace jit {
+
+using caffe2::int8::Int8TensorCPU;
+
+namespace {
+
+caffe2::Tensor from_at_tensor(const c10::IValue& v) {
+  return caffe2::Tensor(autograd::Variable(std::move(v).toTensor()).data());
+}
+
+Int8TensorCPU from_proxy(const c10::IValue& proxy) {
+  auto t = std::move(proxy).toTuple();
+  Int8TensorCPU r;
+  r.t = from_at_tensor(t->elements()[0]);
+  r.scale = t->elements()[1].toDouble();
+  r.zero_point = t->elements()[2].toInt();
+  return r;
+}
+
+at::Tensor to_proxy(const caffe2::Tensor& t) {
+  return autograd::make_variable(at::Tensor(t.UnsafeSharedInstance()), false);
+}
+
+c10::intrusive_ptr<c10::ivalue::Tuple> to_proxy(const Int8TensorCPU& t) {
+  return c10::ivalue::Tuple::create({to_proxy(t.t), t.scale, t.zero_point});
+}
+
+// TODO: replace this with c10 registration when it's ready
+RegisterOperators reg({
+    Operator(
+        // NOTE: we put outout in double parens because it's an output of type
+        // tuple, not a tuple of multiple outputs
+        "c10::quantized_relu((Tensor, float, int) self) -> ((Tensor, float, int))",
+        // TODO: can't use C++ inference - doesn't work yet for tuple types
+        [](Stack& stack) {
+          AT_ASSERT(caffe2::GetRegisteredOperators().count(
+              caffe2::OpRegistryKey("Relu", "DNNLOWP")))
+
+          // TODO: refactor the underlying op implementation and inline it in
+          // c10 kernel
+          caffe2::Workspace ws;
+          ws.CreateBlob("X")->Reset(
+              new Int8TensorCPU(from_proxy(std::move(peek(stack, 0, 1)))));
+
+          auto def = caffe2::CreateOperatorDef(
+              "Relu", "proxy", {"X"}, {"Y"}, caffe2::DeviceOption(), "DNNLOWP");
+          auto op = caffe2::CreateOperator(def, &ws);
+
+          op->Run();
+
+          drop(stack, 1);
+          pack(stack, to_proxy(ws.GetBlob("Y")->Get<Int8TensorCPU>()));
+          return 0;
+        }),
+
+    Operator(
+        "c10::quantize(Tensor X, float? scale = None, int? zero_point = None) -> ((Tensor, float, int))",
+        [](Stack& stack) {
+          AT_ASSERT(caffe2::GetRegisteredOperators().count(
+              caffe2::OpRegistryKey("Quantize", "DNNLOWP")))
+
+          // TODO: refactor the underlying op implementation and inline it in
+          // c10 kernel
+          caffe2::Workspace ws;
+          ws.CreateBlob("X")->Reset(
+              new caffe2::Tensor(from_at_tensor(std::move(peek(stack, 0, 3)))));
+
+          auto def = caffe2::CreateOperatorDef(
+              "Quantize",
+              "proxy",
+              {"X"},
+              {"Y"},
+              caffe2::DeviceOption(),
+              "DNNLOWP");
+          auto s = peek(stack, 1, 3).toOptional<float>();
+          if (s.has_value()) {
+            def.add_arg()->CopyFrom(caffe2::MakeArgument("Y_scale", *s));
+          }
+          auto zp = peek(stack, 2, 3).toOptional<int32_t>();
+          if (zp.has_value()) {
+            def.add_arg()->CopyFrom(caffe2::MakeArgument("Y_zero_point", *zp));
+          }
+          auto op = caffe2::CreateOperator(def, &ws);
+
+          op->Run();
+
+          drop(stack, 3);
+          pack(stack, to_proxy(ws.GetBlob("Y")->Get<Int8TensorCPU>()));
+          return 0;
+        }),
+
+    Operator(
+        "c10::dequantize((Tensor, float, int) x_q) -> Tensor",
+        // TODO: can't use C++ inference - doesn't work yet for tuple types
+        [](Stack& stack) {
+          AT_ASSERT(caffe2::GetRegisteredOperators().count(
+              caffe2::OpRegistryKey("Dequantize", "DNNLOWP")))
+
+          // TODO: refactor the underlying op implementation and inline it in
+          // c10 kernel
+          caffe2::Workspace ws;
+          ws.CreateBlob("X")->Reset(
+              new Int8TensorCPU(from_proxy(std::move(peek(stack, 0, 1)))));
+
+          auto def = caffe2::CreateOperatorDef(
+              "Dequantize",
+              "proxy",
+              {"X"},
+              {"Y"},
+              caffe2::DeviceOption(),
+              "DNNLOWP");
+          auto op = caffe2::CreateOperator(def, &ws);
+
+          op->Run();
+
+          drop(stack, 1);
+          pack(stack, to_proxy(ws.GetBlob("Y")->Get<caffe2::Tensor>()));
+          return 0;
+        }),
+});
+} // namespace
+} // namespace jit
+} // namespace torch