Add FunctionSchema based Operator Registry (#13789)
authorBram Wasti <bwasti@fb.com>
Thu, 6 Dec 2018 01:16:24 +0000 (17:16 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 6 Dec 2018 01:20:24 +0000 (17:20 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13789

This enables creation of operators with FunctionSchema and IValue

Reviewed By: smessmer

Differential Revision: D13008791

fbshipit-source-id: 151efc88ac315f4a0ab0171a99774caaf767ef1e

aten/src/ATen/core/Type.h
aten/src/ATen/core/ivalue.h
aten/src/ATen/templates/Type.h
caffe2/core/blob.h
caffe2/core/operator.cc
caffe2/core/operator.h
caffe2/core/operator_test.cc
caffe2/core/tensor.h

index a699684..296c5c4 100644 (file)
@@ -135,7 +135,10 @@ struct CAFFE2_API Type {
     return backendToDeviceType(backend());
   }
 
-  virtual Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional<Device> to_device={}) const = 0;
+  virtual Tensor copy(
+      const Tensor& src,
+      bool non_blocking = false,
+      c10::optional<Device> to_device = {}) const = 0;
   virtual Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const = 0;
 
   virtual void backward(
@@ -168,7 +171,7 @@ struct CAFFE2_API Type {
 
   /// Constructs the `TensorOptions` from a type and a Device.  Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(optional<Device> device_opt) const {
+  TensorOptions options(c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
       return options(-1);
     } else {
index e358573..381d662 100644 (file)
@@ -575,6 +575,18 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
 
 #undef TORCH_FORALL_TAGS
 
+namespace detail {
+
+struct _guarded_unsigned_long_unique_dummy final {
+  _guarded_unsigned_long_unique_dummy(int64_t){};
+};
+using _guarded_unsigned_long = c10::guts::conditional_t<
+    std::is_same<unsigned long, uint32_t>::value ||
+        std::is_same<unsigned long, uint64_t>::value,
+    _guarded_unsigned_long_unique_dummy,
+    unsigned long>;
+
+} // namespace detail
 
 #define DEFINE_TO(type, method_name) \
 template<> \
@@ -587,7 +599,16 @@ inline type IValue::to<type>() const & { \
 }
 DEFINE_TO(at::Tensor, toTensor)
 DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
+DEFINE_TO(float, toDouble)
 DEFINE_TO(double, toDouble)
+DEFINE_TO(unsigned char, toInt)
+DEFINE_TO(signed char, toInt)
+DEFINE_TO(unsigned short, toInt)
+DEFINE_TO(short, toInt)
+DEFINE_TO(int, toInt)
+DEFINE_TO(uint32_t, toInt)
+DEFINE_TO(uint64_t, toInt)
+DEFINE_TO(detail::_guarded_unsigned_long, toInt)
 DEFINE_TO(int64_t, toInt)
 DEFINE_TO(bool, toBool)
 DEFINE_TO(c10::intrusive_ptr<ivalue::DoubleList>, toDoubleList)
index 60e1c19..3f6f72c 100644 (file)
@@ -106,7 +106,10 @@ struct CAFFE2_API Type {
     return backendToDeviceType(backend());
   }
 
-  virtual Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional<Device> to_device={}) const = 0;
+  virtual Tensor copy(
+      const Tensor& src,
+      bool non_blocking = false,
+      c10::optional<Device> to_device = {}) const = 0;
   virtual Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const = 0;
 
   virtual void backward(
@@ -139,7 +142,7 @@ struct CAFFE2_API Type {
 
   /// Constructs the `TensorOptions` from a type and a Device.  Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(optional<Device> device_opt) const {
+  TensorOptions options(c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
       return options(-1);
     } else {
index 0227c67..ba574c1 100644 (file)
@@ -28,6 +28,29 @@ inline Tensor* BlobSetTensor(Blob* blob, const Tensor& tensor) {
   return blob->Reset<Tensor>(new Tensor(tensor));
 }
 
+inline Tensor GetSizedTensorWithOptions(
+    const Tensor& t,
+    at::IntList dims,
+    at::TensorOptions options) {
+  Tensor tensor = t;
+  if (tensor.GetDevice() == options.device() ||
+      (!tensor.GetDevice().has_index() &&
+       tensor.GetDeviceType() == options.device().type())) {
+    if (tensor.sizes() != dims) {
+      // Resize when the dims doesn't match
+      tensor.Resize(dims);
+    }
+    if (tensor.dtype() == options.dtype()) {
+      tensor.raw_mutable_data();
+    } else {
+      // create a new Tensor when the data_type doesn't match
+      return caffe2::empty(dims, options);
+    }
+    return tensor;
+  }
+  return caffe2::empty(dims, options);
+}
+
 // need to keep both functions that returns Tensor* and the one
 // returns Tensor for clangr codemod
 inline Tensor*
index 9775b83..f83b4f4 100644 (file)
@@ -56,6 +56,16 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
   type_ = operator_def.type();
 }
 
+OperatorBase::OperatorBase(
+    const c10::FunctionSchema& fn_schema,
+    const std::vector<c10::IValue>& inputs,
+    const std::vector<c10::IValue*>& outputs)
+    : fn_schema_(make_unique<c10::FunctionSchema>(fn_schema)),
+      ivalue_inputs_(inputs),
+      ivalue_outputs_(outputs) {
+  output_tensors_.resize(ivalue_outputs_.size());
+}
+
 vector<TensorShape> OperatorBase::InputTensorShapes() const {
   vector<TensorShape> tps;
   for (const auto& blob : inputs_) {
@@ -344,6 +354,15 @@ C10_DEFINE_REGISTRY(
     const OperatorDef&,
     const vector<GradientWrapper>&);
 
+C10_DEFINE_REGISTRY(
+    FunctionSchemaOperatorRegistry,
+    OperatorBase,
+    const c10::FunctionSchema,
+    const std::vector<c10::IValue>&,
+    const std::vector<c10::IValue*>&);
+
+C10_DEFINE_REGISTRY(FunctionSchemaRegistry, FunctionSchemaStorageBase);
+
 GradientOpsMeta GetGradientForOp(
     const OperatorDef& def, const vector<GradientWrapper>& g_output) {
   std::unique_ptr<GradientMakerBase> maker(
@@ -689,6 +708,11 @@ std::set<std::string> GetRegisteredOperators() {
     all_keys.emplace(name);
   }
 
+  // FunctionSchema registered operators
+  for (const auto& name : FunctionSchemaOperatorRegistry()->Keys()) {
+    all_keys.emplace(name);
+  }
+
   return all_keys;
 }
 
index 3db643a..559db16 100644 (file)
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/utils/proto_utils.h"
 
+#include <ATen/core/Tensor.h>
+#include <ATen/core/function_schema.h>
+#include <ATen/core/ivalue.h>
+
 namespace caffe2 {
 
 class CAFFE2_API OperatorBase;
@@ -31,23 +35,50 @@ typedef ObserverBase<OperatorBase> OperatorObserver;
 class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
  public:
   explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
+  explicit OperatorBase(
+      const c10::FunctionSchema&,
+      const std::vector<c10::IValue>&,
+      const std::vector<c10::IValue*>&);
+
   virtual ~OperatorBase() noexcept {}
 
+  /** @brief Return true if the operator was instantiated with OperatorDef
+   * New operators should be instantiated with FunctionSchema
+   */
+  bool isLegacyOperator() const {
+    return !fn_schema_;
+  }
+
+  const c10::FunctionSchema& getFunctionSchema() const {
+    CAFFE_ENFORCE(!isLegacyOperator());
+    return *fn_schema_.get();
+  }
+
   /** @brief Checks if the operator has an argument of the given name.
    */
   inline bool HasArgument(const string& name) const {
-    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
-    return ArgumentHelper::HasArgument(*operator_def_, name);
+    if (isLegacyOperator()) {
+      CAFFE_ENFORCE(operator_def_, "operator_def was null!");
+      return ArgumentHelper::HasArgument(*operator_def_, name);
+    }
+    return getFunctionSchema().argumentIndexWithName(name).has_value();
   }
 
   // Functions that deal with arguments. Basically, this allows us to map an
   // argument name to a specific type of argument that we are trying to access.
   template <typename T>
   inline T GetSingleArgument(const string& name, const T& default_value) const {
-    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
-    return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
-        *operator_def_, name, default_value);
+    if (isLegacyOperator()) {
+      CAFFE_ENFORCE(operator_def_, "operator_def was null!");
+      return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
+          *operator_def_, name, default_value);
+    }
+    auto index = getFunctionSchema().argumentIndexWithName(name);
+    CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
+    const auto& value = ivalue_inputs_[index.value()];
+    return value.template to<T>();
   }
+
   template <typename T>
   inline bool HasSingleArgumentOfType(const string& name) const {
     CAFFE_ENFORCE(operator_def_, "operator_def was null!");
@@ -120,11 +151,26 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
   // TODO(jerryzh): Remove this template
   template <typename T>
   inline T* Output(int idx, DeviceType type) {
-    static_assert(
-        std::is_same<T, Tensor>::value,
-        "Output(int, DeviceType) is only available for Tensor");
-    // When you get a Tensor here it is not fully initialized
-    return BlobGetMutableTensor(outputs_.at(idx), type);
+    if (isLegacyOperator()) {
+      static_assert(
+          std::is_same<T, Tensor>::value,
+          "Output(int, DeviceType) is only available for Tensor");
+      // When you get a Tensor here it is not fully initialized
+      return BlobGetMutableTensor(outputs_.at(idx), type);
+    }
+    auto* ival = ivalue_outputs_[idx];
+    CAFFE_ENFORCE(
+        ival->isTensor(),
+        "Output(int, DeviceType) is only available for IValues that store Tensors");
+    Tensor tensor = caffe2::Tensor(ival->toTensor());
+    if (tensor.GetDeviceType() != type) {
+      // Fix tensor type
+      tensor = Tensor(type);
+      auto at_tensor = at::Tensor(std::move(tensor.getIntrusivePtr()));
+      *ival = IValue(at_tensor);
+    }
+    output_tensors_[idx] = caffe2::Tensor(ival->toTensor());
+    return &output_tensors_[idx];
   }
 
   inline Tensor
@@ -137,10 +183,23 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
 
   inline Tensor*
   OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
-    CAFFE_ENFORCE_WITH_CALLER(
-        options.device_opt() != c10::nullopt,
-        "device must be provided in option.");
-    return BlobGetMutableTensor(outputs_.at(idx), dims, options);
+    if (isLegacyOperator()) {
+      CAFFE_ENFORCE_WITH_CALLER(
+          options.device_opt() != c10::nullopt,
+          "device must be provided in option.");
+      return BlobGetMutableTensor(outputs_.at(idx), dims, options);
+    }
+    auto* ival = ivalue_outputs_[idx];
+    CAFFE_ENFORCE(
+        ival->isTensor(),
+        "Output(int, DeviceType) is only available for IValues that store Tensors");
+    Tensor tensor = caffe2::Tensor(ival->toTensor());
+    tensor = GetSizedTensorWithOptions(tensor, dims, options);
+    auto at_tensor = at::Tensor(std::move(tensor.getIntrusivePtr()));
+    *ival = IValue(at_tensor);
+
+    output_tensors_[idx] = caffe2::Tensor(ival->toTensor());
+    return &output_tensors_[idx];
   }
 
   // Get output Tensor of the operator and CopyFrom the given Tensor
@@ -414,6 +473,15 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
   std::string type_;
   vector<const Blob*> inputs_;
   vector<Blob*> outputs_;
+  // Preferrably use c10::optional, but nvcc doesn't work
+  std::unique_ptr<const c10::FunctionSchema> fn_schema_ = nullptr;
+  vector<c10::IValue> ivalue_inputs_;
+  vector<c10::IValue*> ivalue_outputs_;
+  // HACK
+  // We preserve the fact that Output() returns Tensor*
+  // by storing Tensor in a vector owned by the
+  // operator.
+  vector<caffe2::Tensor> output_tensors_;
 
   int net_position_{kNoNetPositionSet};
 
@@ -450,6 +518,19 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
   C10_DISABLE_COPY_AND_ASSIGN(OperatorBase);
 };
 
+template <>
+inline NetDef OperatorBase::GetSingleArgument<NetDef>(
+    const std::string& name,
+    const NetDef& default_value) const {
+  if (isLegacyOperator()) {
+    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
+    return ArgumentHelper::GetSingleArgument<OperatorDef, NetDef>(
+        *operator_def_, name, default_value);
+  }
+  CAFFE_THROW("Cannot get NetDefs from IValue");
+  return NetDef();
+}
+
 // If your operator does not need any specialized contructor or destructor,
 // you can simply use this to save two lines of code.
 #define USE_SIMPLE_BASE_CTOR_DTOR(name)                                        \
@@ -495,6 +576,15 @@ class Operator : public OperatorBase {
     // constructors will run on that device.
     context_.SwitchToDevice(0);
   }
+  explicit Operator(
+      const c10::FunctionSchema& fn_schema,
+      const std::vector<c10::IValue>& inputs,
+      const std::vector<c10::IValue*>& outputs)
+      : OperatorBase(fn_schema, inputs, outputs) {
+    // In the constructor, we switch to the device so that the child class
+    // constructors will run on that device.
+    context_.SwitchToDevice(0);
+  }
   ~Operator() noexcept override {}
 
   inline const Tensor& Input(
@@ -965,6 +1055,34 @@ C10_DECLARE_REGISTRY(
   REGISTER_HIP_OPERATOR_WITH_ENGINE(name, MIOPEN, __VA_ARGS__) \
   REGISTER_HIP_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__) // Make CUDNN an alias of MIOPEN for HIP ops
 
+C10_DECLARE_REGISTRY(
+    FunctionSchemaOperatorRegistry,
+    OperatorBase,
+    const c10::FunctionSchema,
+    const std::vector<c10::IValue>&,
+    const std::vector<c10::IValue*>&);
+
+struct FunctionSchemaStorageBase {
+  FunctionSchemaStorageBase() {}
+  virtual c10::FunctionSchema getSchema() = 0;
+  virtual ~FunctionSchemaStorageBase() {}
+};
+
+C10_DECLARE_REGISTRY(FunctionSchemaRegistry, FunctionSchemaStorageBase);
+
+#define REGISTER_FUNCTION_SCHEMA_OPERATOR(name, inputs, outputs, impl)        \
+  C10_REGISTER_CLASS(FunctionSchemaOperatorRegistry, name, impl)              \
+  struct FunctionSchemaStorageBase##name : public FunctionSchemaStorageBase { \
+    c10::FunctionSchema getSchema() override {                                \
+      return c10::FunctionSchema(#name, inputs, outputs);                     \
+    }                                                                         \
+  };                                                                          \
+  C10_REGISTER_CLASS(                                                         \
+      FunctionSchemaRegistry, name, FunctionSchemaStorageBase##name)
+
+#define GET_FUNCTION_SCHEMA(name) \
+  FunctionSchemaRegistry()->Create(name)->getSchema()
+
 // StaticLinkingProtector is a helper class that ensures that the Caffe2
 // library is linked correctly with whole archives (in the case of static
 // linking). What happens is that when CreateOperator is called for the first
index 1ce881d..c813e04 100644 (file)
@@ -595,4 +595,90 @@ TEST(IsTestArg, non_standard) {
       "JustTestWithNonStandardIsTestArg");
 }
 
+class TestOperatorWithFunctionSchema final : public Operator<CPUContext> {
+ public:
+  TestOperatorWithFunctionSchema(const OperatorDef& def, Workspace* ws)
+      : Operator<CPUContext>(def, ws) {}
+
+  TestOperatorWithFunctionSchema(
+      const c10::FunctionSchema& f,
+      const std::vector<c10::IValue>& i,
+      const std::vector<c10::IValue*>& o)
+      : Operator<CPUContext>(f, i, o) {
+    if (HasArgument("test_arg")) {
+      test_arg_ =
+          static_cast<float>(this->GetSingleArgument<float>("test_arg", 0.01));
+    }
+  }
+
+  bool RunOnDevice() override {
+    auto out =
+        OutputTensor(0, {1, 1}, at::TensorOptions(TypeMeta::Make<float>()));
+    out->mutable_data<float>()[0] = test_arg_;
+    return true;
+  }
+
+ private:
+  float test_arg_ = 0;
+};
+
+REGISTER_CPU_OPERATOR(
+    TestOperatorWithFunctionSchema,
+    TestOperatorWithFunctionSchema);
+OPERATOR_SCHEMA(TestOperatorWithFunctionSchema)
+    .NumInputs(0, 1)
+    .NumOutputs(0, 1)
+    .Arg("test_arg", "this arg is required", true);
+
+// The new way combines OPERATOR_SCHEMA and REGISTER_OPERATOR
+REGISTER_FUNCTION_SCHEMA_OPERATOR(
+    TestOperatorWithFunctionSchema,
+    {c10::Argument("test_arg")},
+    {c10::Argument("output")},
+    TestOperatorWithFunctionSchema)
+
+TEST(FunctionSchema, Creation) {
+  std::vector<c10::IValue> inputs;
+  float test_val = 1337.0f;
+  inputs.emplace_back(test_val);
+
+  caffe2::Tensor out = TensorCPUFromValues<float>({1, 1}, {123.0f});
+  std::vector<c10::IValue*> outputs;
+  auto t = at::Tensor(std::move(out.getIntrusivePtr()));
+  auto out_val = c10::IValue(t);
+  outputs.emplace_back(&out_val);
+
+  auto fn = FunctionSchemaRegistry()
+                ->Create("TestOperatorWithFunctionSchema")
+                ->getSchema();
+  auto op = FunctionSchemaOperatorRegistry()->Create(
+      "TestOperatorWithFunctionSchema", fn, inputs, outputs);
+
+  op->Run();
+  EXPECT_EQ(out.data<float>()[0], test_val);
+}
+
+TEST(FunctionSchema, OutputChange) {
+  std::vector<c10::IValue> inputs;
+  float test_val = 1337.0f;
+  inputs.emplace_back(test_val);
+
+  // Wrong type
+  caffe2::Tensor out = TensorCPUFromValues<int>({1, 1}, {123});
+  std::vector<c10::IValue*> outputs;
+  auto t = at::Tensor(std::move(out.getIntrusivePtr()));
+  auto out_val = c10::IValue(t);
+  outputs.emplace_back(&out_val);
+
+  auto fn = FunctionSchemaRegistry()
+                ->Create("TestOperatorWithFunctionSchema")
+                ->getSchema();
+  auto op = FunctionSchemaOperatorRegistry()->Create(
+      "TestOperatorWithFunctionSchema", fn, inputs, outputs);
+
+  op->Run();
+  out = caffe2::Tensor(out_val.toTensor());
+  EXPECT_EQ(out.data<float>()[0], test_val);
+}
+
 }  // namespace caffe2
index 4015422..2759a1d 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <ATen/core/UndefinedTensorImpl.h>
 #include <c10/util/intrusive_ptr.h>
+#include "ATen/core/Tensor.h"
 #include "ATen/core/TensorOptions.h"
 
 namespace caffe2 {
@@ -51,6 +52,12 @@ class CAFFE2_API Tensor final {
   }
 
   /**
+   * @brief Creates a caffe2 tensor from an ATen tensor
+   */
+  explicit Tensor(const at::Tensor& tensor)
+      : impl_(std::move(tensor.getIntrusivePtr())) {}
+
+  /**
    * @brief Creates a tensor of the given dimension.
    *
    * Note that the actual data allocation is not going to be carried out until
@@ -245,6 +252,11 @@ class CAFFE2_API Tensor final {
     impl_.get()->ShareExternalPointer(std::move(data_ptr), data_type, capacity);
   }
 
+  const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr()
+      const {
+    return impl_;
+  }
+
   /**
    * Returns a const raw void* pointer of the underlying storage. mutable_data()
    * or raw_mutable_data() must have been called prior to this function call.