From d247296d05e01689e703235702e8fcee2973aabe Mon Sep 17 00:00:00 2001
From: Adwaith Anand
Date: Wed, 12 Jul 2023 18:19:06 +0530
Subject: [PATCH] [FullyConnected] Added NHWC support for FC_Layer inference
 part.

This also contains the unit tests to evaluate it.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Adwaith Anand
---
 nntrainer/layers/fc_layer.cpp                      | 11 ++++++---
 nntrainer/layers/layer_devel.h                     |  8 +++----
 nntrainer/layers/layer_node.cpp                    |  7 +++---
 test/unittest/layers/layers_golden_tests.cpp       |  4 ++++
 .../layers/unittest_layers_fully_connected.cpp     | 26 +++++++++++++++++++++-
 5 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/nntrainer/layers/fc_layer.cpp b/nntrainer/layers/fc_layer.cpp
index 3b7cdeb..afe228e 100644
--- a/nntrainer/layers/fc_layer.cpp
+++ b/nntrainer/layers/fc_layer.cpp
@@ -66,17 +66,22 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   context.setEffDimFlagInputDimension(0, 0b1001);
   context.setDynDimFlagInputDimension(0, 0b1000);
+  bool is_nchw = (getTensorType() == Tformat::NCHW) ? true : false;
   /** set output dimensions */
   auto const &in_dim = context.getInputDimensions()[0];
   output_dims[0] = in_dim;
-  output_dims[0].width(unit);
+  is_nchw ? output_dims[0].width(unit) : output_dims[0].channel(unit);
 
   context.setOutputDimensions(output_dims);
 
   /** set weight specifications */
   // @todo : This NCHW format setting is just temporal, it needs to be set by
   // global configuration
-  TensorDim bias_dim(1, 1, 1, unit, getTensorType(), 0b0001);
-  TensorDim weight_dim(1, 1, in_dim.width(), unit, getTensorType(), 0b0011);
+
+  TensorDim bias_dim(1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1,
+                     getTensorType(), is_nchw ? 0b0001 : 0b0100);
+  TensorDim weight_dim(1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1,
+                       is_nchw ? unit : in_dim.channel(), getTensorType(),
+                       is_nchw ? 0b0011 : 0b0101);
 
   weight_idx[FCParams::weight] = context.requestWeight(
     weight_dim, weight_initializer, weight_regularizer,
diff --git a/nntrainer/layers/layer_devel.h b/nntrainer/layers/layer_devel.h
index 23cf09f..98b1206 100644
--- a/nntrainer/layers/layer_devel.h
+++ b/nntrainer/layers/layer_devel.h
@@ -255,11 +255,9 @@ public:
    * @param Tensor Type : NCHW or NHWC
    */
   void setTensorType(const std::string &values) {
-    if (values.compare("NCHW") || values.compare("nchw")) {
-      tensor_type = ml::train::TensorDim::Format::NCHW;
-    } else {
-      tensor_type = ml::train::TensorDim::Format::NHWC;
-    }
+    tensor_type = (values.compare("NCHW") == 0 || values.compare("nchw") == 0)
+                    ? ml::train::TensorDim::Format::NCHW
+                    : ml::train::TensorDim::Format::NHWC;
   }
 
   /**
diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp
index 4bee254..d489e1e 100644
--- a/nntrainer/layers/layer_node.cpp
+++ b/nntrainer/layers/layer_node.cpp
@@ -249,9 +249,10 @@ void LayerNode::setOutputConnection(unsigned nth, const std::string &name,
 }
 
 void LayerNode::setTensorType(const std::string &type_) {
-  TensorDim::Format type = (type_.compare("NCHW") || type_.compare("nchw"))
-                             ? TensorDim::Format::NCHW
-                             : TensorDim::Format::NHWC;
+  TensorDim::Format type =
+    (type_.compare("NCHW") == 0 || type_.compare("nchw") == 0)
+      ? TensorDim::Format::NCHW
+      : TensorDim::Format::NHWC;
   getLayer()->setTensorType(type);
 }
 
diff --git a/test/unittest/layers/layers_golden_tests.cpp b/test/unittest/layers/layers_golden_tests.cpp
index 8189468..a320e64 100644
--- a/test/unittest/layers/layers_golden_tests.cpp
+++ b/test/unittest/layers/layers_golden_tests.cpp
@@ -50,6 +50,10 @@ static InitLayerContext createInitContext(Layer *layer,
   std::vector parsed;
   from_string(input_shape_str, parsed);
 
+  for (auto &p : parsed) {
+    p.get().setFormat(layer->getTensorType());
+  }
+
   InitLayerContext context({parsed.begin(), parsed.end()}, {true}, false,
                            "golden_test");
   layer->finalize(context);
diff --git a/test/unittest/layers/unittest_layers_fully_connected.cpp b/test/unittest/layers/unittest_layers_fully_connected.cpp
index c25168a..adace85 100644
--- a/test/unittest/layers/unittest_layers_fully_connected.cpp
+++ b/test/unittest/layers/unittest_layers_fully_connected.cpp
@@ -38,6 +38,30 @@ auto fc_basic_no_decay = LayerGoldenTestParamType(
   {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:1:1:10",
   "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
+auto fc_basic_plain_nhwc = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=5"},
+  "3:10:1:1", "fc_plain.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+    LayerGoldenTestParamOptions::SKIP_CALC_GRAD,
+  "nhwc");
+
+auto fc_basic_single_batch_nhwc = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=4"},
+  "1:10:1:1", "fc_single_batch.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+    LayerGoldenTestParamOptions::SKIP_CALC_GRAD,
+  "nhwc");
+
+auto fc_basic_no_decay_nhwc = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::FullyConnectedLayer>,
+  {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:10:1:1",
+  "fc_plain.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+    LayerGoldenTestParamOptions::SKIP_CALC_GRAD,
+  "nhwc");
+
 GTEST_PARAMETER_TEST(FullyConnected, LayerGoldenTest,
                      ::testing::Values(fc_basic_plain, fc_basic_single_batch,
-                                       fc_basic_no_decay));
+                                       fc_basic_no_decay, fc_basic_plain_nhwc,
+                                       fc_basic_single_batch_nhwc,
+                                       fc_basic_no_decay_nhwc));
--
2.7.4
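
Note on the dimension mapping: the standalone C++ sketch below is not part of the
patch; the Dim struct and the fcDims() helper are illustrative stand-ins, not
nntrainer's TensorDim API. It only shows how the output, weight and bias shapes
computed in fc_layer.cpp above differ between NCHW and NHWC, using the same input
shapes as the golden tests.

// Editor's sketch (assumptions: Dim and fcDims() are hypothetical names; the
// real code builds nntrainer::TensorDim objects in FullyConnectedLayer::finalize).
#include <cstdio>

struct Dim {
  unsigned b, c, h, w; // batch, channel, height, width
};

struct FcDims {
  Dim out, weight, bias;
};

// NCHW keeps features on the width axis, so `unit` replaces width;
// NHWC keeps features on the channel axis, so `unit` replaces channel.
FcDims fcDims(const Dim &in, unsigned unit, bool is_nchw) {
  Dim out = in;
  if (is_nchw)
    out.w = unit;
  else
    out.c = unit;

  // Mirrors the ternaries in the patch:
  //   NCHW: weight (1, 1, in.w, unit), bias (1, 1, 1, unit)
  //   NHWC: weight (1, unit, 1, in.c), bias (1, unit, 1, 1)
  Dim weight = is_nchw ? Dim{1, 1, in.w, unit} : Dim{1, unit, 1, in.c};
  Dim bias = is_nchw ? Dim{1, 1, 1, unit} : Dim{1, unit, 1, 1};
  return {out, weight, bias};
}

int main() {
  // Input shapes taken from the golden tests: "3:1:1:10" (nchw), "3:10:1:1" (nhwc).
  FcDims nchw = fcDims(Dim{3, 1, 1, 10}, 5, true);
  FcDims nhwc = fcDims(Dim{3, 10, 1, 1}, 5, false);

  std::printf("nchw out %u:%u:%u:%u, weight %u:%u:%u:%u, bias %u:%u:%u:%u\n",
              nchw.out.b, nchw.out.c, nchw.out.h, nchw.out.w, nchw.weight.b,
              nchw.weight.c, nchw.weight.h, nchw.weight.w, nchw.bias.b,
              nchw.bias.c, nchw.bias.h, nchw.bias.w);
  std::printf("nhwc out %u:%u:%u:%u, weight %u:%u:%u:%u, bias %u:%u:%u:%u\n",
              nhwc.out.b, nhwc.out.c, nhwc.out.h, nhwc.out.w, nhwc.weight.b,
              nhwc.weight.c, nhwc.weight.h, nhwc.weight.w, nhwc.bias.b,
              nhwc.bias.c, nhwc.bias.h, nhwc.bias.w);
  return 0;
}

Running the sketch prints weight 1:1:10:5 and bias 1:1:1:5 for nchw versus
weight 1:5:1:10 and bias 1:5:1:1 for nhwc, which is the same mapping the
is_nchw ternaries in fc_layer.cpp encode.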