Shim caffe2 GetRepeatedArgument helper for use with Ivalue (#16519)
authorBram Wasti <bwasti@fb.com>
Fri, 1 Feb 2019 01:25:16 +0000 (17:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 1 Feb 2019 01:33:57 +0000 (17:33 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16519

GetRepeatedArguments is needed for some ops

Reviewed By: dzhulgakov

Differential Revision: D13864293

fbshipit-source-id: a39255cd391c28acd75a6f0e81d558542417e032

caffe2/core/operator.h

index 7464c06..2534286 100644 (file)
@@ -86,12 +86,23 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
         *operator_def_, name);
   }
   template <typename T>
+  inline vector<T> GetVectorFromIValueList(const c10::IValue& value) const {
+    return value.template to<vector<T>>();
+  }
+
+  template <typename T>
   inline vector<T> GetRepeatedArgument(
       const string& name,
       const vector<T>& default_value = {}) const {
-    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
-    return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
-        *operator_def_, name, default_value);
+    if (isLegacyOperator()) {
+      CAFFE_ENFORCE(operator_def_, "operator_def was null!");
+      return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
+          *operator_def_, name, default_value);
+    }
+    auto index = getFunctionSchema().argumentIndexWithName(name);
+    CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
+    const auto& value = ivalue_inputs_[index.value()];
+    return GetVectorFromIValueList<T>(value);
   }
 
   // Get the inputs and outputs as specific types.
@@ -553,6 +564,38 @@ inline NetDef OperatorBase::GetSingleArgument<NetDef>(
   return NetDef();
 }
 
+template <>
+inline vector<int> OperatorBase::GetVectorFromIValueList<int>(
+    const c10::IValue& value) const {
+  const auto& vs = value.toIntListRef();
+  vector<int> out;
+  out.reserve(vs.size());
+  for (const auto& v : vs) {
+    out.emplace_back(v);
+  }
+  return out;
+}
+
+template <>
+inline vector<float> OperatorBase::GetVectorFromIValueList<float>(
+    const c10::IValue& value) const {
+  const auto& vs = value.toDoubleListRef();
+  vector<float> out;
+  out.reserve(vs.size());
+  for (const auto& v : vs) {
+    out.emplace_back(v);
+  }
+  return out;
+}
+
+template <>
+inline vector<string> OperatorBase::GetVectorFromIValueList<string>(
+    const c10::IValue& value) const {
+  CAFFE_THROW("Cannot extract vector<string> from ivalue.");
+  vector<string> out;
+  return out;
+}
+
 // OP_SINGLE_ARG provides a shorter initialization choice for initialization of
 // member variables for the class constructors.
 // This is a workaround for CUDA9.2 and GCC7