Automatically register c10 ops with JIT (#16534)
author    Sebastian Messmer <messmer@fb.com>
          Thu, 7 Feb 2019 05:14:20 +0000 (21:14 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Thu, 7 Feb 2019 05:21:33 +0000 (21:21 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16534

All c10 ops from the c10 dispatcher are now automatically registered with JIT. Instead of hand-written per-op glue (the deleted torch/csrc/jit/c10_ops/layer_norm.cpp), the JIT installs an OpRegistrationListener on the c10 dispatcher at static initialization time; the listener is replayed over all already-registered ops and fires for every op registered later, wrapping each one in a jit::Operator.
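
A quick way to observe the effect: an op that exists only in the c10 dispatcher
(such as the experimental caffe2 layer norm) is now resolvable through the JIT
operator registry. A minimal sketch (getAllOperatorsFor and
Symbol::fromQualString are assumptions about the JIT registry API, not part of
this diff):

    #include <torch/csrc/jit/operator.h>

    void check_c10_op_visible_to_jit() {
      const auto& ops = torch::jit::getAllOperatorsFor(
          c10::Symbol::fromQualString("caffe2::layer_norm_dont_use_this_op_yet"));
      AT_ASSERT(!ops.empty());  // the auto-generated jit::Operator wrapper
    }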

Reviewed By: dzhulgakov

Differential Revision: D13869275

fbshipit-source-id: 5ab5dec5b983fe661f977f9d29d8036768cdcab6

aten/src/ATen/core/dispatch/Dispatcher.cpp
aten/src/ATen/core/dispatch/Dispatcher.h
aten/src/ATen/core/jit_type.h
c10/util/flat_hash_map.h
caffe2/operators/layer_norm_op.cc
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/c10_ops/layer_norm.cpp [deleted file]
torch/csrc/jit/operator.h
torch/csrc/jit/register_c10_ops.cpp [new file with mode: 0644]

diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp
index 832b687..da0df1a 100644
@@ -1,8 +1,89 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 
 namespace c10 {
+
+namespace detail {
+class RegistrationListenerList final {
+public:
+  void addListener(std::unique_ptr<OpRegistrationListener> listener) {
+    listeners_.push_back(std::move(listener));
+  }
+
+  void callOnOperatorRegistered(const OperatorHandle& op) {
+    for (auto& listener : listeners_) {
+      listener->onOperatorRegistered(op);
+    }
+  }
+
+  void callOnOperatorDeregistered(const OperatorHandle& op) {
+    for (auto& listener : listeners_) {
+      listener->onOperatorDeregistered(op);
+    }
+  }
+private:
+  std::vector<std::unique_ptr<OpRegistrationListener>> listeners_;
+};
+}
+
+OpRegistrationListener::~OpRegistrationListener() {}
+
+Dispatcher::Dispatcher()
+: operators_()
+, listeners_(guts::make_unique<detail::RegistrationListenerList>())
+, mutex_() {}
+
+Dispatcher::~Dispatcher() {}
+
 C10_EXPORT Dispatcher& Dispatcher::singleton() {
   static Dispatcher _singleton;
   return _singleton;
 }
+
+OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
+  // we need a lock to avoid concurrent writes
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  operators_.emplace_back(std::move(schema));
+  auto op = OperatorHandle(--operators_.end());
+
+  // note: call listeners *after* operator is added, i.e. dispatcher is already valid for new op
+  listeners_->callOnOperatorRegistered(op);
+
+  return op;
+}
+
+void Dispatcher::deregisterSchema(const OperatorHandle& op) {
+  // we need a lock to avoid concurrent writes
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
+    AT_ERROR("Tried to deregister op schema that still has kernels registered");
+  }
+
+  // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
+  listeners_->callOnOperatorDeregistered(op);
+
+  operators_.erase(op.operatorDefIterator_);
+}
+
+void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
+  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+  op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func});
+}
+
+void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
+  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+  op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key);
+}
+
+void Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
+    listener->onOperatorRegistered(OperatorHandle(iter));
+  }
+
+  listeners_->addListener(std::move(listener));
+}
+
 }
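
Usage sketch for the registration surface above. The ordering is the point:
listeners fire after registerSchema inserts the op and before deregisterSchema
erases it, so the OperatorHandle stays valid inside either callback. The schema
string below is hypothetical; parseSchema is declared in
torch/csrc/jit/operator.h (also touched in this diff):

    #include <ATen/core/dispatch/Dispatcher.h>
    #include <torch/csrc/jit/operator.h>

    void schema_lifecycle_sketch() {
      // hypothetical op, built from a schema string
      c10::FunctionSchema schema =
          torch::jit::parseSchema("demo::my_op(Tensor input) -> Tensor");

      // listeners (e.g. the JIT) are notified *after* the op is inserted
      c10::OperatorHandle op =
          c10::Dispatcher::singleton().registerSchema(std::move(schema));

      // listeners are notified *before* the op is erased; deregistering
      // while kernels are still registered raises AT_ERROR
      c10::Dispatcher::singleton().deregisterSchema(op);
    }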
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index bcca890..d1e71ba 100644
@@ -52,6 +52,23 @@ private:
 };
 
 /**
+ * Implement this interface and register your instance with the dispatcher
+ * to get notified when operators are registered or deregistered with
+ * the dispatcher.
+ */
+class CAFFE2_API OpRegistrationListener {
+public:
+  virtual ~OpRegistrationListener();
+
+  virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
+  virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
+};
+
+namespace detail {
+class RegistrationListenerList;
+}
+
+/**
  * Top-level dispatch interface for dispatching via the dynamic dispatcher.
  */
 class CAFFE2_API Dispatcher final {
@@ -67,6 +84,8 @@ private:
   friend class OperatorHandle;
 
 public:
+  ~Dispatcher();
+
   // Implementation note: this class abstracts over the fact that we have per-operator
   // dispatch tables.  This could be easily adjusted to have a single global hash
   // table.
@@ -100,8 +119,19 @@ public:
    */
   OpKernel lookup(const OperatorHandle& op, const Stack* stack) const;
 
+  /**
+   * Add a listener that gets called whenever a new op is registered or an existing
+   * op is deregistered. Immediately after registering, this listener gets called
+   * for all previously registered ops, so it can be used to keep track of ops
+   * registered with this dispatcher.
+   */
+  void addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);
+
 private:
+  Dispatcher();
+
   std::list<OperatorDef> operators_;
+  std::unique_ptr<detail::RegistrationListenerList> listeners_;
   std::mutex mutex_;
 };
 
@@ -130,35 +160,6 @@ private:
 };
 
 
-
-inline OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
-  // we need a lock to avoid concurrent writes
-  std::lock_guard<std::mutex> lock(mutex_);
-
-  operators_.emplace_back(std::move(schema));
-  return OperatorHandle(--operators_.end());
-}
-
-inline void Dispatcher::deregisterSchema(const OperatorHandle& op) {
-  // we need a lock to avoid concurrent writes
-  std::lock_guard<std::mutex> lock(mutex_);
-
-  if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
-    AT_ERROR("Tried to deregister op schema that still has kernels registered");
-  }
-  operators_.erase(op.operatorDefIterator_);
-}
-
-inline void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
-  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
-  op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func});
-}
-
-inline void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
-  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
-  op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key);
-}
-
 inline OpKernel Dispatcher::lookup(const OperatorHandle& op, const Stack* stack) const {
   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
   const DispatchTableEntry& kernel = op.operatorDefIterator_->dispatchTable.lookup(stack);
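
A sketch of a standalone consumer of the new listener API: a listener that logs
every registration event. As documented above, addRegistrationListener first
replays the listener over all previously registered ops and then subscribes it
to future changes. The schema().name() accessor and the make_unique header
location are assumptions:

    #include <ATen/core/dispatch/Dispatcher.h>
    #include <c10/util/C++17.h>  // assumed home of c10::guts::make_unique
    #include <iostream>

    class LoggingListener final : public c10::OpRegistrationListener {
    public:
      void onOperatorRegistered(const c10::OperatorHandle& op) override {
        std::cout << "registered: " << op.schema().name() << std::endl;
      }
      void onOperatorDeregistered(const c10::OperatorHandle& op) override {
        std::cout << "deregistered: " << op.schema().name() << std::endl;
      }
    };

    void install_logging_listener() {
      c10::Dispatcher::singleton().addRegistrationListener(
          c10::guts::make_unique<LoggingListener>());
    }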
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 0b18380..abaa41d 100644
@@ -514,7 +514,7 @@ private:
 
 struct DictType;
 using DictTypePtr = std::shared_ptr<DictType>;
-struct DictType : public Type {
+struct CAFFE2_API DictType : public Type {
   friend struct Type;
   static const TypeKind Kind = TypeKind::DictType;
 
diff --git a/c10/util/flat_hash_map.h b/c10/util/flat_hash_map.h
index 7f37919..5955b17 100644
@@ -924,6 +924,7 @@ private:
         return static_cast<Equal &>(*this)(lhs, rhs);
     }
 
+private:
     struct convertible_to_iterator
     {
         EntryPointer it;
diff --git a/caffe2/operators/layer_norm_op.cc b/caffe2/operators/layer_norm_op.cc
index f452f4b..45eab81 100644
@@ -196,12 +196,32 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass
   c10::ArrayRef<c10::IValue> inputs = torch::jit::peekSlice(*stack, 0, 3, 6);
   c10::ArrayRef<c10::IValue> outputs = torch::jit::peekSlice(*stack, 3, 3, 6);
 
-  caffe2::Tensor X{c10::C10Tensor(inputs[0].toTensor())};
+
+  caffe2::Tensor X{inputs[0].toTensor()};
   int64_t axis = inputs[1].toInt();
   float epsilon = inputs[2].toDouble();
-  caffe2::Tensor Y{c10::C10Tensor(outputs[0].toTensor())};
-  caffe2::Tensor mean{c10::C10Tensor(outputs[1].toTensor())};
-  caffe2::Tensor sig{c10::C10Tensor(outputs[2].toTensor())};
+
+  auto device = X.GetDevice();
+
+  caffe2::Tensor Y, mean, sig;
+  if (outputs[0].isTensor()) {
+    Y = caffe2::Tensor(std::move(torch::jit::peek(*stack, 0, 3)).toTensor());
+  }
+  if (outputs[1].isTensor()) {
+    mean = caffe2::Tensor(std::move(torch::jit::peek(*stack, 1, 3)).toTensor());
+  }
+  if (outputs[2].isTensor()) {
+    sig = caffe2::Tensor(std::move(torch::jit::peek(*stack, 2, 3)).toTensor());
+  }
+  if (!Y.defined()) {
+    Y = caffe2::empty({0}, device);
+  }
+  if (!mean.defined()) {
+    mean = caffe2::empty({0}, device);
+  }
+  if (!sig.defined()) {
+    sig = caffe2::empty({0}, device);
+  }
 
   caffe2::CPUContext context;
   Cache* cache = static_cast<Cache*>(cache_);
@@ -226,9 +246,9 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass
 
   torch::jit::drop(*stack, 6);
   torch::jit::push(*stack,
-    at::Tensor(c10::C10Tensor(std::move(Y))),
-    at::Tensor(c10::C10Tensor(std::move(mean))),
-    at::Tensor(c10::C10Tensor(std::move(sig)))
+    at::Tensor(std::move(Y)),
+    at::Tensor(std::move(mean)),
+    at::Tensor(std::move(sig))
   );
 
   return;
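
The idiom above, in isolation: a boxed c10 kernel reuses a caller-provided
output when the corresponding stack slot holds a defined tensor and allocates a
fresh one otherwise. A condensed sketch for a hypothetical single-output kernel
(two stack slots: input, optional output), using only the helpers that appear
in the diff:

    #include <ATen/core/stack.h>
    #include <caffe2/core/tensor.h>

    void unary_kernel_sketch(c10::Stack* stack) {
      caffe2::Tensor in{torch::jit::peek(*stack, 0, 2).toTensor()};
      auto device = in.GetDevice();

      caffe2::Tensor out;
      if (torch::jit::peek(*stack, 1, 2).isTensor()) {  // caller preallocated
        out = caffe2::Tensor(std::move(torch::jit::peek(*stack, 1, 2)).toTensor());
      }
      if (!out.defined()) {
        out = caffe2::empty({0}, device);  // kernel will resize as needed
      }

      // ... compute into `out` ...

      torch::jit::drop(*stack, 2);
      torch::jit::push(*stack, at::Tensor(std::move(out)));
    }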
diff --git a/tools/build_variables.py b/tools/build_variables.py
index fc56cc5..0eccbae 100644
@@ -59,6 +59,7 @@ libtorch_sources = [
     "torch/csrc/jit/ir.cpp",
     "torch/csrc/jit/caffe2_operator.cpp",
     "torch/csrc/jit/register_caffe2_ops.cpp",
+    "torch/csrc/jit/register_c10_ops.cpp",
     "torch/csrc/jit/symbolic_script.cpp",
     "torch/csrc/jit/operator.cpp",
     "torch/csrc/jit/passes/alias_analysis.cpp",
@@ -101,7 +102,6 @@ libtorch_sources = [
     "torch/csrc/jit/script/lexer.cpp",
     "torch/csrc/jit/script/module.cpp",
     "torch/csrc/jit/tracer.cpp",
-    "torch/csrc/jit/c10_ops/layer_norm.cpp",
     "torch/csrc/utils/tensor_flatten.cpp",
     "torch/csrc/utils/variadic.cpp",
     "torch/csrc/jit/fuser/kernel_cache.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index d7678e4..6884aef 100644
@@ -135,6 +135,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/ir.cpp
   ${TORCH_SRC_DIR}/csrc/jit/operator.cpp
   ${TORCH_SRC_DIR}/csrc/jit/caffe2_operator.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/register_c10_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/symbolic_script.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/alias_analysis.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
@@ -179,7 +180,6 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
   ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
-  ${TORCH_SRC_DIR}/csrc/jit/c10_ops/layer_norm.cpp
   ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
   ${TORCH_SRC_DIR}/csrc/utils/variadic.cpp
   ${TORCH_SRC_DIR}/csrc/jit/fuser/kernel_cache.cpp
diff --git a/torch/csrc/jit/c10_ops/layer_norm.cpp b/torch/csrc/jit/c10_ops/layer_norm.cpp
deleted file mode 100644
index 02fdf89..0000000
--- a/torch/csrc/jit/c10_ops/layer_norm.cpp
+++ /dev/null
@@ -1,64 +0,0 @@
-#include <ATen/core/dispatch/Dispatcher.h>
-#include <ATen/core/opschema/layer_norm.h>
-#include <ATen/core/ivalue.h>
-#include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/jit/operator.h>
-#include <ATen/core/stack.h>
-#include <torch/csrc/jit/custom_operator.h>
-
-using at::Tensor;
-using c10::IValue;
-using c10::ArrayRef;
-
-namespace torch {
-namespace jit {
-
-// TODO This code is currently written specifically for LayerNorm, but it is
-//      *not* the plan to have to write this manually for each operation.
-//      This is just a proof of concept. To expand this to all operators,
-//      we'd ideally not need any per-operator code (possibly thanks to boxing
-//      or templates). If that's not possible, then we should at least offer
-//      a macro that takes this burden so that we only need to write one line
-//      for each operation we want to support (i.e. the macro invocation).
-
-// TODO This currently only handles tensors with requires_grad==False correctly.
-//      It should also handle autograd.
-
-namespace {
-RegisterOperators reg({
-  Operator(
-    //Note: This schema is: caffe2::layer_norm_dont_use_this_op_yet(Tensor input, int axis, float epsilon, Tensor? output = None, Tensor? output_mean = None, Tensor? output_stdev = None) -> (Tensor, Tensor, Tensor)
-    c10::core::opschema::LayerNorm().schema(),
-    [](Stack& stack) {
-        Tensor tensor_input = std::move(stack[stack.size()-6]).toTensor();
-        if (tensor_input.requires_grad()) {
-          throw std::runtime_error("Autograd not yet supported for c10 ops.");
-        }
-        auto device = tensor_input.device();
-
-        // unwrap inputs from variable
-        torch::jit::peek(stack, 0, 6) = torch::autograd::Variable(std::move(tensor_input)).data();
-
-        // allocate the output tensors that aren't set yet
-        for (int i = 3; i < 6; ++i) {
-          // TODO this should just check for isNone, not for undefined tensor. @wanchaol is working on this.
-          if (torch::jit::peek(stack, i, 6).isNone() || !torch::jit::peek(stack, i, 6).toTensor().defined()) {
-            torch::jit::peek(stack, i, 6) = at::empty({0}, device);
-          }
-        }
-
-        // call caffe2 kernel
-        c10::Dispatcher::singleton().lookup(c10::core::opschema::LayerNorm(), &stack).call(&stack);
-
-        // wrap outputs into Variable
-        for (int i = 0; i < 3; ++i) {
-          torch::jit::peek(stack, i, 3) = torch::autograd::make_variable(std::move(torch::jit::peek(stack, i, 3)).toTensor(), false);
-        }
-
-        return 0;
-      })
-  });
-}
-
-}
-}
diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h
index 0c16352..75b56de 100644
@@ -27,6 +27,36 @@ TORCH_API FunctionSchema parseSchema(const std::string& schema);
 
 using OperationCreator = std::function<Operation(const Node*)>;
 
+/*
+ * Note: JIT relies on Operator instances having static lifetime, because
+ * it, for example, stores a non-owning FunctionSchema* pointer in the Node
+ * class, which points to the function schema stored in the Operator instance.
+ * Also, jit::Operator is meant to store more operator-related information like
+ * symbolic derivatives, which also requires them to have static lifetime
+ * so that changes to symbolic derivatives are remembered.
+ *
+ * Currently, the c10 operator library doesn't store jit::Operator instances;
+ * instead, we use a listener pattern that notifies JIT about changes in the
+ * c10 operator library and then registers jit::Operator instances to the JIT
+ * operator registry, acting as wrappers to the c10 operators.
+ *
+ * However, that results in code duplication as JIT and c10 will likely get
+ * their own mechanisms for storing derivatives and other operator related
+ * information, and all of this would have to be wrapped from c10 into JIT.
+ *
+ * We should consider merging the JIT and c10 registries, moving jit::Operator
+ * to c10 and storing these jit::Operator instances in the c10 operator library
+ * instead, allowing us to have these mechanisms only implemented once.
+ * However, the current jit::Operator implementation has additional features
+ * like OperationCreator that aren't needed in c10 (they're only used for
+ * prim ops like If/Else or While which wouldn't be in the c10 operator library),
+ * and which depend on other JIT features which we don't want to move to c10
+ * (notably jit/ir.h). We might, however, be able to split jit::Operator into
+ * a c10::Operator with the core features and a jit::Operator that adds the
+ * JIT-only features like OperationCreator, and then use c10::Operator in the
+ * c10 operator library.
+ */
+
 struct TORCH_API Operator {
   Operator(FunctionSchema schema, OperationCreator op_creator)
       : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
diff --git a/torch/csrc/jit/register_c10_ops.cpp b/torch/csrc/jit/register_c10_ops.cpp
new file mode 100644
index 0000000..0ed473c
--- /dev/null
+++ b/torch/csrc/jit/register_c10_ops.cpp
@@ -0,0 +1,66 @@
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/csrc/jit/operator.h>
+
+namespace torch {
+namespace jit {
+namespace {
+
+// TODO This currently only handles tensors with requires_grad==False correctly.
+//      It should also handle autograd.
+Operator createOperatorFromC10(const c10::OperatorHandle& op) {
+  return Operator(op.schema(), [op](Stack& stack) {
+      const auto input_size = op.schema().arguments().size();
+      const auto output_size = op.schema().returns().size();
+
+      // unwrap tensor inputs from variable
+      for (auto iter = stack.end() - input_size; iter != stack.end(); ++iter) {
+        // TODO Remove the .defined() check once we don't have undefined tensors on the stack anymore (@wanchaol is working on this)
+        if (iter->isTensor() && iter->toTensor().defined()) {
+          at::Tensor tensor = std::move(*iter).toTensor();
+          if (tensor.requires_grad()) {
+            throw std::runtime_error("Autograd not yet supported for c10 ops.");
+          }
+          *iter = torch::autograd::Variable(std::move(tensor)).data();
+        }
+      }
+
+      c10::Dispatcher::singleton().lookup(op, &stack).call(&stack);
+
+      // wrap tensor outputs as variable
+      for (auto iter = stack.end() - output_size; iter != stack.end(); ++iter) {
+        if (iter->isTensor()) {
+          *iter = torch::autograd::make_variable(std::move(*iter).toTensor());
+        }
+      }
+
+      return 0;
+  });
+}
+
+class RegistrationListener final : public c10::OpRegistrationListener {
+public:
+  void onOperatorRegistered(const c10::OperatorHandle& op) override {
+    torch::jit::registerOperator(createOperatorFromC10(op));
+  }
+
+  void onOperatorDeregistered(const c10::OperatorHandle& op) override {
+    // TODO Do something like torch::jit::deregisterOperator(op.schema());
+  }
+};
+
+struct Registerer final {
+  Registerer() {
+    // this immediately calls the listener on all existing ops,
+    // and calls it in the future whenever a new op is registered
+    c10::Dispatcher::singleton().addRegistrationListener(
+      c10::guts::make_unique<RegistrationListener>()
+    );
+  }
+};
+
+// global instance to run its constructor on startup
+Registerer registerer;
+
+}
+}
+}
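
End to end, the wrapper installed by this file makes a c10 kernel callable
through the ordinary JIT operator machinery. A sketch for the layer norm op
migrated in this PR; getAllOperatorsFor and Operator::getOperation are spelling
assumptions about the JIT registry of this era:

    #include <ATen/core/stack.h>
    #include <torch/csrc/jit/operator.h>

    void call_layer_norm_through_jit(at::Tensor input) {
      const auto& ops = torch::jit::getAllOperatorsFor(
          c10::Symbol::fromQualString("caffe2::layer_norm_dont_use_this_op_yet"));
      AT_ASSERT(!ops.empty());

      // schema: (Tensor input, int axis, float epsilon,
      //          Tensor? output, Tensor? output_mean, Tensor? output_stdev)
      torch::jit::Stack stack;
      torch::jit::push(stack, std::move(input), /*axis=*/1, /*epsilon=*/1e-5,
                       c10::IValue(), c10::IValue(), c10::IValue());

      ops.front()->getOperation()(stack);  // leaves (Y, mean, stdev) on the stack
    }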