Expose c10 operators to caffe2 by operator name (#18160)
authorSebastian Messmer <messmer@fb.com>
Tue, 26 Mar 2019 19:29:02 +0000 (12:29 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 19:36:11 +0000 (12:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18160

When exposing a c10 operator to the caffe2 frontend, don't use the operator schema but use the operator name instead.
This allows us to get rid of the existing mechanism for operator schema registration in a diff stacked on top.

Reviewed By: dzhulgakov

Differential Revision: D14513420

fbshipit-source-id: 6b08a9c6d9497eaf18b62361dd44bc07c7b4b76b

21 files changed:
caffe2/core/operator_c10wrapper.h
caffe2/operators/experimental/c10/schemas/add.cc
caffe2/operators/experimental/c10/schemas/averaged_loss.cc
caffe2/operators/experimental/c10/schemas/batch_gather.cc
caffe2/operators/experimental/c10/schemas/batch_matmul.cc
caffe2/operators/experimental/c10/schemas/cast.cc
caffe2/operators/experimental/c10/schemas/concat.cc
caffe2/operators/experimental/c10/schemas/enforce_finite.cc
caffe2/operators/experimental/c10/schemas/expand_dims.cc
caffe2/operators/experimental/c10/schemas/fc.cc
caffe2/operators/experimental/c10/schemas/filler.cc
caffe2/operators/experimental/c10/schemas/flatten.cc
caffe2/operators/experimental/c10/schemas/mul.cc
caffe2/operators/experimental/c10/schemas/relu.cc
caffe2/operators/experimental/c10/schemas/sigmoid.cc
caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc
caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc
caffe2/operators/experimental/c10/schemas/stop_gradient.cc
caffe2/operators/layer_norm_op.cc
caffe2/operators/layer_norm_op.cu
caffe2/operators/layer_norm_op.h

index 649fcbd..ade91ca 100644 (file)
@@ -1,10 +1,11 @@
 #pragma once
 
 #include <ATen/core/dispatch/Dispatcher.h>
-#include "caffe2/core/operator.h"
+#include <ATen/core/ivalue.h>
 #include <c10/util/ArrayRef.h>
+#include <c10/util/C++17.h>
 #include <c10/util/Metaprogramming.h>
-#include <ATen/core/ivalue.h>
+#include "caffe2/core/operator.h"
 
 namespace caffe2 {
 
@@ -202,10 +203,19 @@ class C10OperatorWrapper final : public Operator<Context> {
 template <class Context>
 inline std::function<
     std::unique_ptr<OperatorBase>(const OperatorDef&, Workspace*)>
-createC10OperatorWrapper(const c10::OperatorHandle& op_handle) {
-  return [op_handle](const OperatorDef& op_def, Workspace* ws) {
+createC10OperatorWrapper(const char* op_name, const char* overload_name) {
+  return [op_name, overload_name](const OperatorDef& op_def, Workspace* ws) {
+    auto op_handle =
+        c10::Dispatcher::singleton().findSchema(op_name, overload_name);
+    AT_ASSERTM(
+        op_handle.has_value(),
+        "Tried to register c10 operator ",
+        op_name,
+        ".",
+        overload_name,
+        " with caffe2, but didn't find the c10 operator.");
     return c10::guts::make_unique<C10OperatorWrapper<Context>>(
-        op_handle, op_def, ws);
+        *op_handle, op_def, ws);
   };
 }
 
@@ -215,18 +225,30 @@ createC10OperatorWrapper(const c10::OperatorHandle& op_handle) {
 #ifndef C10_MOBILE
 // TODO Currently we only register the CPU variant. This is going to be fixed
 //      once the tensor detemplatization lands.
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(OperatorHandle, Name) \
-  REGISTER_CPU_OPERATOR_CREATOR(                                            \
-      Name, detail::createC10OperatorWrapper<CPUContext>(OperatorHandle))
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(OperatorHandle, Name) \
-  REGISTER_CUDA_OPERATOR_CREATOR(                                            \
-      Name, detail::createC10OperatorWrapper<CUDAContext>(OperatorHandle))
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP(OperatorHandle, Name) \
-  REGISTER_HIP_OPERATOR_CREATOR(                                            \
-      Name, detail::createC10OperatorWrapper<HIPContext>(OperatorHandle))
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(        \
+    OperatorName, Name)                                       \
+  REGISTER_CPU_OPERATOR_CREATOR(                              \
+      Name,                                                   \
+      ::caffe2::detail::createC10OperatorWrapper<CPUContext>( \
+          OperatorName, ""))
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(        \
+    OperatorName, Name)                                        \
+  REGISTER_CUDA_OPERATOR_CREATOR(                              \
+      Name,                                                    \
+      ::caffe2::detail::createC10OperatorWrapper<CUDAContext>( \
+          OperatorName, ""))
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP(        \
+    OperatorName, Name)                                       \
+  REGISTER_HIP_OPERATOR_CREATOR(                              \
+      Name,                                                   \
+      ::caffe2::detail::createC10OperatorWrapper<HIPContext>( \
+          OperatorName, ""))
 #else
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(OperatorHandle, Name)
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(OperatorHandle, Name)
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP(OperatorHandle, Name)
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU( \
+    OperatorName, Name)
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA( \
+    OperatorName, Name)
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP( \
+    OperatorName, Name)
 #endif
 } // namespace caffe2
index dae61f0..63ecc97 100644 (file)
@@ -23,6 +23,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::Add(),
+    "_c10_experimental::Add",
     C10Add_DontUseThisOpYet)
 }
index 4fbf909..dfc41a6 100644 (file)
@@ -20,6 +20,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::AveragedLoss(),
+    "_c10_experimental::AveragedLoss",
     C10AveragedLoss_DontUseThisOpYet)
 }
index a4d9ca3..9fb84e5 100644 (file)
@@ -21,6 +21,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::BatchGather(),
+    "_c10_experimental::BatchGather",
     C10BatchGather_DontUseThisOpYet)
 }
index 5273d2b..addb95e 100644 (file)
@@ -24,8 +24,7 @@ C10_DEFINE_OP_SCHEMA(
 }
 
 namespace caffe2 {
-
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::BatchMatmul(),
+    "_c10_experimental::BatchMatmul",
     C10BatchMatMul_DontUseThisOpYet)
 }
index cb0a990..c1133ce 100644 (file)
@@ -24,6 +24,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::Cast(),
+    "_c10_experimental::Cast",
     C10Cast_DontUseThisOpYet)
 }
index 82f4840..d9a7e33 100644 (file)
@@ -24,6 +24,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::Concat(),
+    "_c10_experimental::Concat",
     C10Concat_DontUseThisOpYet)
 }
index db18274..c8d58de 100644 (file)
@@ -19,6 +19,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::EnforceFinite(),
+    "_c10_experimental::EnforceFinite",
     C10EnforceFinite_DontUseThisOpYet)
 }
index f7bbfd3..e2c4c75 100644 (file)
@@ -22,8 +22,7 @@ C10_DEFINE_OP_SCHEMA(
 }
 
 namespace caffe2 {
-
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::ExpandDims(),
+    "_c10_experimental::ExpandDims",
     C10ExpandDims_DontUseThisOpYet)
 }
index f90ded5..773964e 100644 (file)
@@ -23,8 +23,7 @@ C10_DEFINE_OP_SCHEMA(
 }
 
 namespace caffe2 {
-
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::FullyConnected(),
+    "_c10_experimental::FullyConnected",
     C10FC_DontUseThisOpYet)
 }
index f4fdbbb..8fe8707 100644 (file)
@@ -86,19 +86,18 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::ConstantFill(),
+    "_c10_experimental::ConstantFill",
     C10ConstantFill_DontUseThisOpYet)
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::UniformFill(),
+    "_c10_experimental::UniformFill",
     C10UniformFill_DontUseThisOpYet)
-
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::GivenTensorFill(),
+    "_c10_experimental::GivenTensorFill",
     C10GivenTensorFill_DontUseThisOpYet)
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::GivenTensorIntFill(),
+    "_c10_experimental::GivenTensorIntFill",
     C10GivenTensorIntFill_DontUseThisOpYet)
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::GivenTensorInt64Fill(),
+    "_c10_experimental::GivenTensorInt64Fill",
     C10GivenTensorInt64Fill_DontUseThisOpYet)
 } // namespace caffe2
index f74ddde..42353b2 100644 (file)
@@ -21,6 +21,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::Flatten(),
+    "_c10_experimental::Flatten",
     C10Flatten_DontUseThisOpYet)
 }
index 7fe8bc7..af7a7b7 100644 (file)
@@ -24,6 +24,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::Mul(),
+    "_c10_experimental::Mul",
     C10Mul_DontUseThisOpYet)
 }
index eb85589..43528d9 100644 (file)
@@ -20,6 +20,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::Relu(),
+    "_c10_experimental::Relu",
     C10Relu_DontUseThisOpYet)
 }
index 0d71dc6..2261a19 100644 (file)
@@ -20,6 +20,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::Sigmoid(),
+    "_c10_experimental::Sigmoid",
     C10Sigmoid_DontUseThisOpYet)
 }
index 35a0897..d1be6b9 100644 (file)
@@ -24,6 +24,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::SigmoidCrossEntropyWithLogits(),
+    "_c10_experimental::SigmoidCrossEntropyWithLogits",
     C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet)
 }
index ee3d646..f26abcf 100644 (file)
@@ -22,6 +22,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::SparseLengthsSum(),
+    "_c10_experimental::SparseLengthsSum",
     C10SparseLengthsSum_DontUseThisOpYet)
 }
index a32c379..3305845 100644 (file)
@@ -20,6 +20,6 @@ C10_DEFINE_OP_SCHEMA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    ops::StopGradient(),
+    "_c10_experimental::StopGradient",
     C10StopGradient_DontUseThisOpYet)
 }
index b25a41a..1add4e5 100644 (file)
@@ -1,7 +1,5 @@
 #include "caffe2/operators/layer_norm_op.h"
 
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
 #include <c10/core/Tensor.h>
 
 #include "caffe2/core/operator_c10wrapper.h"
@@ -201,6 +199,6 @@ C10_REGISTER_CAFFE2_OPERATOR_CPU(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    _c10_ops::LayerNorm(),
+    "_caffe2::LayerNorm",
     C10LayerNorm_DontUseThisOpYet);
 }
index d9b3a20..a37ac66 100644 (file)
@@ -284,6 +284,6 @@ C10_REGISTER_CAFFE2_OPERATOR_CUDA(
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(
-    _c10_ops::LayerNorm(),
+    "_caffe2::LayerNorm",
     C10LayerNorm_DontUseThisOpYet);
 }
index a34a082..4ecc12a 100644 (file)
@@ -4,8 +4,6 @@
 #include <array>
 #include <vector>
 
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 #include "caffe2/core/types.h"