Fix InputSize/OutputSize for ivalue based operators (#17579)
authorSebastian Messmer <messmer@fb.com>
Mon, 4 Mar 2019 22:17:11 +0000 (14:17 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 4 Mar 2019 22:20:12 +0000 (14:20 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17579

These methods previously just returned 0 when it was not a legacy operator,
making it impossible to convert some operators.

Reviewed By: dzhulgakov

Differential Revision: D14253094

fbshipit-source-id: 72bfdcf6da291a4ab80d1e0ceb20984b86edc408

caffe2/core/operator.cc
caffe2/core/operator.h

index 16a696f..30a4d50 100644 (file)
@@ -34,6 +34,7 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
       device_option_(
           operator_def.has_device_option() ? operator_def.device_option()
                                            : DeviceOption()),
+      input_size_(operator_def.input_size()),
       event_(caffe2::make_unique<Event>(device_option_)) {
   static GlobalInitIsCalledGuard guard;
   for (const string& input_str : operator_def.input()) {
@@ -56,14 +57,44 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
   type_ = operator_def.type();
 }
 
+namespace {
+int compute_input_size_(const std::vector<c10::IValue>& inputs) {
+  if (inputs.empty()) {
+    return 0;
+  }
+  if (inputs[0].isTensorList()) {
+    // if the first input is a tensor list, we get input tensors by indexing
+    // into that list. currently, this means that only tensors from that list
+    // are accessible as inputs. any hypothetical input tensors that come after
+    // the list are not accessible.
+    return inputs[0].toTensorListRef().size();
+  }
+  // it's not a tensor list. Count the number of tensor inputs and return them.
+  size_t num_tensor_inputs = 0;
+  bool found_nontensor = false;
+  for (const auto& input : inputs) {
+    if (input.isTensor()) {
+      AT_ASSERTM(
+          !found_nontensor,
+          "All tensor arguments must come before non-tensor arguments");
+      ++num_tensor_inputs;
+    } else {
+      found_nontensor = true;
+    }
+  }
+  return num_tensor_inputs;
+}
+} // namespace
+
 OperatorBase::OperatorBase(
     const c10::FunctionSchema& fn_schema,
     std::vector<c10::IValue> inputs,
     std::vector<c10::IValue*> outputs)
     : fn_schema_(make_unique<c10::FunctionSchema>(std::move(fn_schema))),
       ivalue_inputs_(std::move(inputs)),
-      ivalue_outputs_(std::move(outputs)) {
-  input_tensors_.resize(ivalue_inputs_.size());
+      ivalue_outputs_(std::move(outputs)),
+      input_size_(compute_input_size_(ivalue_inputs_)) {
+  input_tensors_.resize(input_size_);
   output_tensors_.resize(ivalue_outputs_.size());
 }
 
index 75c73c6..d4c109e 100644 (file)
@@ -35,10 +35,18 @@ typedef ObserverBase<OperatorBase> OperatorObserver;
 class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
  public:
   explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
+
+  /*
+   * Notes: All outputs ivalues must be tensors. Input ivalue list must start
+   * with all tensors ("inputs" in caffe2 terminology),
+   * followed by non-tensors ("arguments" in caffe2 terminology).
+   * Alternatively, inputs can be one tensor list ivalue followed by non-tensors
+   * to represent operators with a variable number of inputs.
+   */
   explicit OperatorBase(
-      const c10::FunctionSchema&,
-      std::vector<c10::IValue>,
-      std::vector<c10::IValue*>);
+      const c10::FunctionSchema& schema,
+      std::vector<c10::IValue> inputs,
+      std::vector<c10::IValue*> outputs);
 
   virtual ~OperatorBase() noexcept {}
 
@@ -326,10 +334,14 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
   }
 
   inline int InputSize() const {
-    return inputs_.size();
+    return input_size_;
   }
+
   inline int OutputSize() const {
-    return outputs_.size();
+    if (isLegacyOperator()) {
+      return outputs_.size();
+    }
+    return ivalue_outputs_.size();
   }
   inline const vector<const Blob*>& Inputs() const { return inputs_; }
   inline const vector<Blob*>& Outputs() { return outputs_; }
@@ -542,6 +554,8 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
   vector<caffe2::Tensor> input_tensors_;
   vector<caffe2::Tensor> output_tensors_;
 
+  int input_size_;
+
   int net_position_{kNoNetPositionSet};
 
   ExecutorHelper* helper_ = nullptr;