class FC : public NeuralNetOperator {
public:
- FC() : NeuralNetOperator(NNKind::FC) {}
+ FC(int axis = 1, int axisW = 1)
+ : NeuralNetOperator(NNKind::FC), axis_(axis), axisW_(axisW) {}
~FC() {}
NOMNIGRAPH_DEFINE_NN_RTTI(FC);
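+ // Flattening axes for the input and weight tensors, mirroring the
+ // Caffe2 FC "axis" and "axis_w" arguments (both default to 1).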
+ int getAxis() const {
+ return axis_;
+ }
+
+ int getAxisW() const {
+ return axisW_;
+ }
+
+ void setAxis(int axis) {
+ axis_ = axis;
+ }
+
+ void setAxisW(int axisW) {
+ axisW_ = axisW;
+ }
+
private:
+ int axis_;
+ int axisW_;
};
class GivenTensorFill : public NeuralNetOperator {
- Max : float
FC
+- Axis : int : 1
+- AxisW : int : 1
+
GivenTensorFill
Concat
- Axis : int : -1
namespace caffe2 {
+USE_IDEEP_DEF_ALIASES();
+
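+// Collapse an N-dimensional shape into the 2-D shape expected by the
+// ideep inner_product primitives: dimensions before `axis` are folded
+// into dim0 and the remaining dimensions into dim1. For example, dims
+// {2, 3, 4, 5} with axis = 2 become {6, 20}. A 2-D shape (or axis == 1)
+// is returned unchanged.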
+static inline itensor::dims CanonicalDims(itensor::dims adims, int32_t axis) {
+ CAFFE_ENFORCE(axis < (int32_t)adims.size(), "Invalid axis!");
+ CAFFE_ENFORCE(axis > (int32_t)-adims.size(), "Invalid axis!");
+ if (adims.size() == 2 || axis == 1) {
+ return adims;
+ }
+ if (axis < 0) {
+ axis += (int32_t)adims.size();
+ }
+
+ auto dim0 = std::accumulate(
+ adims.begin(),
+ adims.begin() + axis,
+ 1,
+ std::multiplies<itensor::dim_t>());
+ auto dim1 = std::accumulate(
+ adims.begin() + axis, adims.end(), 1, std::multiplies<itensor::dim_t>());
+ return itensor::dims({dim0, dim1});
+}
+
class IDEEPFullyConnectedOp final : public IDEEPOperator {
public:
USE_IDEEP_DEF_ALIASES();
IDEEPFullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
: IDEEPOperator(operator_def, ws),
- float16_compute_(
- OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}
+ axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
+ axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)) {}
virtual ~IDEEPFullyConnectedOp() {}
bool RunOnDevice() override {
const auto& filter = Input(FILTER);
auto* Y = Output(OUTPUT);
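+ // Canonicalize X and the filter to 2-D around axis_ / axis_w_ before
+ // calling into the ideep inner_product primitive.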
+ itensor X_in = X;
+ auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
+ if (X_in.get_dims() != X_dims) {
+ X_in.reshape(X_dims);
+ }
+
+ itensor filter_in = filter;
+ auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
+ if (filter_in.get_dims() != filter_dims) {
+ filter_in.reshape(filter_dims);
+ }
+
if (InputSize() > BIAS) {
- ideep::inner_product_forward::compute(X, filter, Input(BIAS), *Y);
+ ideep::inner_product_forward::compute(X_in, filter_in, Input(BIAS), *Y);
} else {
- ideep::inner_product_forward::compute(X, filter, *Y);
+ ideep::inner_product_forward::compute(X_in, filter_in, *Y);
}
return true;
}
private:
- bool float16_compute_;
+ size_t axis_{1};
+ size_t axis_w_{1};
INPUT_TAGS(INPUT, FILTER, BIAS);
OUTPUT_TAGS(OUTPUT);
IDEEPFullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
: IDEEPOperator(operator_def, ws),
- float16_compute_(
- OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}
+ axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
+ axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)) {}
virtual ~IDEEPFullyConnectedGradientOp() {}
bool RunOnDevice() override {
auto* dfilter = Output(FILTER_GRAD);
auto* dbias = Output(BIAS_GRAD);
- ideep::inner_product_backward_weights::compute(X, dY, *dfilter, *dbias);
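+ // Apply the same 2-D canonicalization as the forward pass so the
+ // gradient computation sees matching shapes.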
+ itensor X_in = X;
+ auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
+ if (X_in.get_dims() != X_dims) {
+ X_in.reshape(X_dims);
+ }
+
+ itensor filter_in = filter;
+ auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
+ if (filter_in.get_dims() != filter_dims) {
+ filter_in.reshape(filter_dims);
+ }
+
+ ideep::inner_product_backward_weights::compute(X_in, dY, *dfilter, *dbias);
if (OutputSize() > INPUT_GRAD) {
ideep::inner_product_backward_data::compute(
- dY, filter, X.get_dims(), *Output(INPUT_GRAD));
+ dY, filter_in, X_in.get_dims(), *Output(INPUT_GRAD));
}
return true;
}
private:
- bool float16_compute_;
+ size_t axis_{1};
+ size_t axis_w_{1};
INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
OUTPUT_TAGS(FILTER_GRAD, BIAS_GRAD, INPUT_GRAD);
#include <limits>
-#include "caffe2/opt/converter.h"
#include "caffe2/core/logging.h"
+#include "caffe2/opt/converter.h"
#include "nomnigraph/Graph/Algorithms.h"
return argMap;
}
-repr::NeuralNetOperator::NNLayout getLayout(std::map<std::string, caffe2::Argument> argMap) {
+repr::NeuralNetOperator::NNLayout getLayout(
+ std::map<std::string, caffe2::Argument> argMap) {
auto arg = argMap.find("order");
if (arg != argMap.end()) {
auto order = argMap["order"].s();
return op;
}
-std::vector<int> getKernelShape(std::map<std::string, caffe2::Argument> argMap) {
+std::vector<int> getKernelShape(
+ std::map<std::string, caffe2::Argument> argMap) {
// There are literally three ways to define shapes in Conv in Caffe2
std::vector<int> kernelShape;
if (argMap.count("kernel")) {
};
REGISTER_CONVERTER(Concat, ConcatConverter);
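+
+// Carries the Caffe2 FC "axis" / "axis_w" arguments onto the nomnigraph
+// repr::FC node.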
+class FCConverter : public Converter {
+ std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator(
+ const OperatorDef& op) override {
+ std::unique_ptr<repr::NeuralNetOperator> nnOp =
+ util::make_unique<repr::FC>();
+ auto argMap = getArgumentsFromOperator(op);
+
+ auto c = dyn_cast<repr::FC>(nnOp.get());
+ if (argMap.count("axis")) {
+ CAFFE_ENFORCE(argMap["axis"].has_i(), "Invalid axis argument");
+ int axis = static_cast<int>(argMap["axis"].i());
+ c->setAxis(axis);
+ }
+ if (argMap.count("axis_w")) {
+ CAFFE_ENFORCE(argMap["axis_w"].has_i(), "Invalid axis_w argument");
+ int axis_w = static_cast<int>(argMap["axis_w"].i());
+ c->setAxisW(axis_w);
+ }
+
+ return nnOp;
+ }
+ // Does not override default converter to OperatorDef
+
+ virtual ~FCConverter() {}
+};
+REGISTER_CONVERTER(FC, FCConverter);
+
} // namespace
std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
return nnOp;
}
-
/// \brief Ingest a caffe2 protobuf model and output an NNModule.
/// \param net The caffe2 protobuf NetDef
repr::NNModule convertToNNModule(
for (const auto& op : net.op()) {
auto opNode = dfg.createNode(); // Create an empty node for the operator.
// First calculate in-edges (data dependencies).
- for (const auto &input : op.input()) {
+ for (const auto& input : op.input()) {
// If we've never seen this tensor, make one.
if (!blobMap.count(input)) {
auto tensor = util::make_unique<repr::Tensor>(input);
}
// Then save outputs into the blobMap for later consumption.
- for (const auto &output : op.output()) {
+ for (const auto& output : op.output()) {
auto tensor = util::make_unique<repr::Tensor>(output);
auto tensorNode =
dfg.createNode(unique_dyn_cast<repr::NeuralNetData>(tensor));
externalInputNames.size(),
" unused blobs: ",
os.str());
- // Otherwise, we add the blobs to the graph as no-ops
+ // Otherwise, we add the blobs to the graph as no-ops
} else {
for (const auto& input : externalInputNames) {
blobMap[input] = dfg.createNode(util::make_unique<repr::Tensor>(input));
caffe2::OperatorDef convertToOperatorDef(
const repr::NNGraph::NodeRef& instrNode) {
- auto *nnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
+ auto* nnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
auto op_type = nnOp->getName();
- auto *annotation = nnOp->getAnnotation();
+ auto* annotation = nnOp->getAnnotation();
caffe2::OperatorDef op;
if (ConverterRegistry()->Has(op_type)) {
return c2_annotation;
}
-caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m) {
+caffe2::NetDef convertToCaffe2Proto(repr::NNModule& m) {
auto predictNet = caffe2::NetDef();
return convertToCaffe2Proto(m, predictNet);
}
return out;
}
-caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m, const caffe2::NetDef& oldNet) {
+caffe2::NetDef convertToCaffe2Proto(
+ repr::NNModule& m,
+ const caffe2::NetDef& oldNet) {
auto predictNet = caffe2::NetDef();
// We copy the old net rather than mutate it.
predictNet.CopyFrom(oldNet);
// Simply iterate through the CFG and populate data dependencies
// with the DFG
- for (const auto &bbNode : m.controlFlow.getMutableNodes()) {
+ for (const auto& bbNode : m.controlFlow.getMutableNodes()) {
if (bbNode->getOutEdges().size() > 1) {
CAFFE_THROW("Control flow not yet supported in Caffe2 converter.");
}
for (const auto& instrNode : bb.getInstructions()) {
caffe2::OperatorDef op = convertToOperatorDef(instrNode);
- for (const auto &inEdge : instrNode->getInEdges()) {
- auto *tensorNode =
+ for (const auto& inEdge : instrNode->getInEdges()) {
+ auto* tensorNode =
dyn_cast<repr::NeuralNetData>(inEdge->tail()->data().get());
*op.add_input() = tensorNode->getName();
}
- for (const auto &outEdge : instrNode->getOutEdges()) {
- auto *tensorNode =
+ for (const auto& outEdge : instrNode->getOutEdges()) {
+ auto* tensorNode =
dyn_cast<repr::NeuralNetData>(outEdge->head()->data().get());
*op.add_output() = tensorNode->getName();
}
- auto *nnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
+ auto* nnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
if (nnOp->getLayout() != repr::NeuralNetOperator::NNLayout::Undefined) {
-
caffe2::Argument* arg = nullptr;
for (int i = 0; i < op.arg_size(); ++i) {
auto arg_ = op.mutable_arg(i);
from __future__ import unicode_literals
import unittest
+from functools import reduce
import hypothesis.strategies as st
from hypothesis import given, settings
import numpy as np
for i in range(3):
self.assertGradientChecks(gc, op, [X, W, b], i, [0])
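+
+ # Run FC / FCGradient with a non-default `axis` on both device options
+ # in `dc` and check that Y, dW and db agree.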
+ @given(n=st.integers(1, 5),
+ m=st.integers(1, 5),
+ c=st.integers(1, 5),
+ h=st.integers(1, 5),
+ w=st.integers(1, 5),
+ axis=st.integers(1, 3),
+ **mu.gcs)
+ def test_fc_with_axis(self, n, m, c, h, w, axis, gc, dc):
+ X = np.random.rand(n, c, h, w).astype(np.float32) - 0.5
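+ # k: product of the X dims from `axis` on (the flattened inner size);
+ # nn: product of the X dims before `axis` (the flattened outer size).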
+ k = reduce((lambda x, y: x * y), [n, c, h, w][axis - 4:])
+ nn = reduce((lambda x, y: x * y), [n, c, h, w][:axis])
+ W = np.random.rand(m, k).astype(np.float32) - 0.5
+ b = np.random.rand(m).astype(np.float32) - 0.5
+ dY = np.random.rand(nn, m).astype(np.float32) - 0.5
+
+ op0 = core.CreateOperator(
+ 'FC',
+ ['X', 'W', 'b'],
+ ["Y"],
+ axis=axis,
+ device_option=dc[0]
+ )
+
+ op0_bw = core.CreateOperator(
+ 'FCGradient',
+ ['X', 'W', 'dY'],
+ ["dW", "db"],
+ axis=axis,
+ device_option=dc[0]
+ )
+
+ workspace.ResetWorkspace()
+ workspace.FeedBlob('X', X, dc[0])
+ workspace.FeedBlob('W', W, dc[0])
+ workspace.FeedBlob('b', b, dc[0])
+ workspace.RunOperatorOnce(op0)
+ Y0 = workspace.FetchBlob('Y')
+
+ workspace.FeedBlob('dY', dY, dc[0])
+ workspace.RunOperatorOnce(op0_bw)
+ dW0 = workspace.FetchBlob('dW')
+ db0 = workspace.FetchBlob('db')
+
+ op1 = core.CreateOperator(
+ 'FC',
+ ['X', 'W', 'b'],
+ ["Y"],
+ axis=axis,
+ device_option=dc[1]
+ )
+
+ op1_bw = core.CreateOperator(
+ 'FCGradient',
+ ['X', 'W', 'dY'],
+ ["dW", "db"],
+ axis=axis,
+ device_option=dc[1]
+ )
+
+ workspace.SwitchWorkspace("_device_check_", True)
+ workspace.FeedBlob('X', X, dc[1])
+ workspace.FeedBlob('W', W, dc[1])
+ workspace.FeedBlob('b', b, dc[1])
+ workspace.RunOperatorOnce(op1)
+ Y1 = workspace.FetchBlob('Y')
+
+ workspace.FeedBlob('dY', dY, dc[1])
+ workspace.RunOperatorOnce(op1_bw)
+ dW1 = workspace.FetchBlob('dW')
+ db1 = workspace.FetchBlob('db')
+
+ Y0 = Y0.flatten()
+ Y1 = Y1.flatten()
+ if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
+ print(Y1)
+ print(Y0)
+ print(np.max(np.abs(Y1 - Y0)))
+ self.assertTrue(False)
+
+ dW0 = dW0.flatten()
+ dW1 = dW1.flatten()
+ if not np.allclose(dW0, dW1, atol=0.01, rtol=0.01):
+ print(dW1)
+ print(dW0)
+ print(np.max(np.abs(dW1 - dW0)))
+ self.assertTrue(False)
+
+ db0 = db0.flatten()
+ db1 = db1.flatten()
+ if not np.allclose(db0, db1, atol=0.01, rtol=0.01):
+ print(db1)
+ print(db0)
+ print(np.max(np.abs(db1 - db0)))
+ self.assertTrue(False)
+
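+ # Same cross-device comparison, this time flattening the weight tensor
+ # at a non-default `axis_w`.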
+ @given(n=st.integers(1, 5),
+ o=st.integers(1, 5),
+ i=st.integers(1, 5),
+ h=st.integers(1, 5),
+ w=st.integers(1, 5),
+ axis_w=st.integers(1, 3),
+ **mu.gcs)
+ def test_fc_with_axis_w(self, n, o, i, h, w, axis_w, gc, dc):
+ W = np.random.rand(o, i, h, w).astype(np.float32) - 0.5
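+ # k: product of the W dims from `axis_w` on; m: product of the W dims
+ # before `axis_w` (the output size).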
+ k = reduce((lambda x, y: x * y), [o, i, h, w][axis_w - 4:])
+ m = reduce((lambda x, y: x * y), [o, i, h, w][:axis_w])
+ X = np.random.rand(n, k).astype(np.float32) - 0.5
+ b = np.random.rand(m).astype(np.float32) - 0.5
+ dY = np.random.rand(n, m).astype(np.float32) - 0.5
+
+ op0 = core.CreateOperator(
+ 'FC',
+ ['X', 'W', 'b'],
+ ["Y"],
+ axis_w=axis_w,
+ device_option=dc[0]
+ )
+
+ op0_bw = core.CreateOperator(
+ 'FCGradient',
+ ['X', 'W', 'dY'],
+ ["dW", "db"],
+ axis_w=axis_w,
+ device_option=dc[0]
+ )
+
+ workspace.ResetWorkspace()
+ workspace.FeedBlob('X', X, dc[0])
+ workspace.FeedBlob('W', W, dc[0])
+ workspace.FeedBlob('b', b, dc[0])
+ workspace.RunOperatorOnce(op0)
+ Y0 = workspace.FetchBlob('Y')
+
+ workspace.FeedBlob('dY', dY, dc[0])
+ workspace.RunOperatorOnce(op0_bw)
+ dW0 = workspace.FetchBlob('dW')
+ db0 = workspace.FetchBlob('db')
+
+ op1 = core.CreateOperator(
+ 'FC',
+ ['X', 'W', 'b'],
+ ["Y"],
+ axis_w=axis_w,
+ device_option=dc[1]
+ )
+
+ op1_bw = core.CreateOperator(
+ 'FCGradient',
+ ['X', 'W', 'dY'],
+ ["dW", "db"],
+ axis_w=axis_w,
+ device_option=dc[1]
+ )
+
+ workspace.SwitchWorkspace("_device_check_", True)
+ workspace.FeedBlob('X', X, dc[1])
+ workspace.FeedBlob('W', W, dc[1])
+ workspace.FeedBlob('b', b, dc[1])
+ workspace.RunOperatorOnce(op1)
+ Y1 = workspace.FetchBlob('Y')
+
+ workspace.FeedBlob('dY', dY, dc[1])
+ workspace.RunOperatorOnce(op1_bw)
+ dW1 = workspace.FetchBlob('dW')
+ db1 = workspace.FetchBlob('db')
+
+ Y0 = Y0.flatten()
+ Y1 = Y1.flatten()
+ if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
+ print(Y1)
+ print(Y0)
+ print(np.max(np.abs(Y1 - Y0)))
+ self.assertTrue(False)
+
+ dW0 = dW0.flatten()
+ dW1 = dW1.flatten()
+ if not np.allclose(dW0, dW1, atol=0.01, rtol=0.01):
+ print(dW1)
+ print(dW0)
+ print(np.max(np.abs(dW1 - dW0)))
+ self.assertTrue(False)
+
+ db0 = db0.flatten()
+ db1 = db1.flatten()
+ if not np.allclose(db0, db1, atol=0.01, rtol=0.01):
+ print(db1)
+ print(db0)
+ print(np.max(np.abs(db1 - db0)))
+ self.assertTrue(False)
+
if __name__ == "__main__":
unittest.main()