From e4c1b51d82a925a2d66fd1f297136317e75d735c Mon Sep 17 00:00:00 2001 From: Bram Wasti Date: Thu, 31 Jan 2019 17:25:16 -0800 Subject: [PATCH] Shim caffe2 GetRepeatedArgument helper for use with Ivalue (#16519) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16519 GetRepeatedArguments is needed for some ops Reviewed By: dzhulgakov Differential Revision: D13864293 fbshipit-source-id: a39255cd391c28acd75a6f0e81d558542417e032 --- caffe2/core/operator.h | 49 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 7464c06..2534286 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -86,12 +86,23 @@ class CAFFE2_API OperatorBase : public Observable { *operator_def_, name); } template + inline vector GetVectorFromIValueList(const c10::IValue& value) const { + return value.template to>(); + } + + template inline vector GetRepeatedArgument( const string& name, const vector& default_value = {}) const { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetRepeatedArgument( - *operator_def_, name, default_value); + if (isLegacyOperator()) { + CAFFE_ENFORCE(operator_def_, "operator_def was null!"); + return ArgumentHelper::GetRepeatedArgument( + *operator_def_, name, default_value); + } + auto index = getFunctionSchema().argumentIndexWithName(name); + CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); + const auto& value = ivalue_inputs_[index.value()]; + return GetVectorFromIValueList(value); } // Get the inputs and outputs as specific types. @@ -553,6 +564,38 @@ inline NetDef OperatorBase::GetSingleArgument( return NetDef(); } +template <> +inline vector OperatorBase::GetVectorFromIValueList( + const c10::IValue& value) const { + const auto& vs = value.toIntListRef(); + vector out; + out.reserve(vs.size()); + for (const auto& v : vs) { + out.emplace_back(v); + } + return out; +} + +template <> +inline vector OperatorBase::GetVectorFromIValueList( + const c10::IValue& value) const { + const auto& vs = value.toDoubleListRef(); + vector out; + out.reserve(vs.size()); + for (const auto& v : vs) { + out.emplace_back(v); + } + return out; +} + +template <> +inline vector OperatorBase::GetVectorFromIValueList( + const c10::IValue& value) const { + CAFFE_THROW("Cannot extract vector from ivalue."); + vector out; + return out; +} + // OP_SINGLE_ARG provides a shorter initialization choice for initialization of // member variables for the class constructors. // This is a workaround for CUDA9.2 and GCC7 -- 2.7.4