*operator_def_, name);
}
template <typename T>
+ inline vector<T> GetVectorFromIValueList(const c10::IValue& value) const {
+ return value.template to<vector<T>>();
+ }
+
+ template <typename T>
inline vector<T> GetRepeatedArgument(
const string& name,
const vector<T>& default_value = {}) const {
- CAFFE_ENFORCE(operator_def_, "operator_def was null!");
- return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
- *operator_def_, name, default_value);
+ if (isLegacyOperator()) {
+ CAFFE_ENFORCE(operator_def_, "operator_def was null!");
+ return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
+ *operator_def_, name, default_value);
+ }
+ auto index = getFunctionSchema().argumentIndexWithName(name);
+ CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
+ const auto& value = ivalue_inputs_[index.value()];
+ return GetVectorFromIValueList<T>(value);
}
// Get the inputs and outputs as specific types.
return NetDef();
}
+template <>
+inline vector<int> OperatorBase::GetVectorFromIValueList<int>(
+ const c10::IValue& value) const {
+ const auto& vs = value.toIntListRef();
+ vector<int> out;
+ out.reserve(vs.size());
+ for (const auto& v : vs) {
+ out.emplace_back(v);
+ }
+ return out;
+}
+
+template <>
+inline vector<float> OperatorBase::GetVectorFromIValueList<float>(
+ const c10::IValue& value) const {
+ const auto& vs = value.toDoubleListRef();
+ vector<float> out;
+ out.reserve(vs.size());
+ for (const auto& v : vs) {
+ out.emplace_back(v);
+ }
+ return out;
+}
+
+template <>
+inline vector<string> OperatorBase::GetVectorFromIValueList<string>(
+ const c10::IValue& value) const {
+ CAFFE_THROW("Cannot extract vector<string> from ivalue.");
+ vector<string> out;
+ return out;
+}
+
// OP_SINGLE_ARG provides a shorter initialization choice for initialization of
// member variables for the class constructors.
// This is a workaround for CUDA9.2 and GCC7