From 8fedde5530f307bc89be4506f8fa742e3379f63b Mon Sep 17 00:00:00 2001 From: Duc Ngo Date: Thu, 13 Dec 2018 20:42:59 -0800 Subject: [PATCH] caffe2 - easy - test utils to create operator (#15180) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15180 Test utils to create an operator On top of D13370461 Reviewed By: ZolotukhinM Differential Revision: D13382773 fbshipit-source-id: a88040ed5a60f31d3e73f1f958219cd7338dc52e --- caffe2/core/test_utils.cc | 24 ++++++++++++++++++++++++ caffe2/core/test_utils.h | 21 +++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/caffe2/core/test_utils.cc b/caffe2/core/test_utils.cc index 41a7b3c..eafe5e2 100644 --- a/caffe2/core/test_utils.cc +++ b/caffe2/core/test_utils.cc @@ -18,5 +18,29 @@ caffe2::Tensor* createTensor( return BlobGetMutableTensor(workspace->CreateBlob(name), caffe2::CPU); } +caffe2::OperatorDef* createOperator( + const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + caffe2::NetDef* net) { + auto* op = net->add_op(); + op->set_type(type); + for (const auto& in : inputs) { + op->add_input(in); + } + for (const auto& out : outputs) { + op->add_output(out); + } + return op; +} + +NetMutator& NetMutator::newOp( + const std::string& type, + const std::vector& inputs, + const std::vector& outputs) { + createOperator(type, inputs, outputs, net_); + return *this; +} + } // namespace testing } // namespace caffe2 diff --git a/caffe2/core/test_utils.h b/caffe2/core/test_utils.h index fea24d3..cd4763b 100644 --- a/caffe2/core/test_utils.h +++ b/caffe2/core/test_utils.h @@ -29,6 +29,27 @@ caffe2::Tensor* createTensor( const std::string& name, caffe2::Workspace* workspace); +// Create a new operator in the net. +caffe2::OperatorDef* createOperator( + const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + caffe2::NetDef* net); + +// Coincise util class to mutate a net in a chaining fashion. +class NetMutator { + public: + explicit NetMutator(caffe2::NetDef* net) : net_(net) {} + + NetMutator& newOp( + const std::string& type, + const std::vector& inputs, + const std::vector& outputs); + + private: + caffe2::NetDef* net_; +}; + } // namespace testing } // namespace caffe2 -- 2.7.4