From f661ddf3ad248ea4912790b980d3b68cfbd09c5f Mon Sep 17 00:00:00 2001 From: hyeonseok lee Date: Tue, 26 Oct 2021 17:44:17 +0900 Subject: [PATCH] [layer] override setBatch in bn layer - Override setBatch function Self evaluation: Build test: [X]Passed [ ]Failed [ ]Skipped Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: hyeonseok lee --- nntrainer/layers/bn_layer.cpp | 6 ++++++ nntrainer/layers/bn_layer.h | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp index 252b2ce..5d58cf4 100644 --- a/nntrainer/layers/bn_layer.cpp +++ b/nntrainer/layers/bn_layer.cpp @@ -255,4 +255,10 @@ void BatchNormalizationLayer::exportTo(Exporter &exporter, exporter.saveResult(bn_props, method, this); } +void BatchNormalizationLayer::setBatch(RunLayerContext &context, + unsigned int batch) { + context.updateTensor(wt_idx[BNParams::deviation], batch); + context.updateTensor(wt_idx[BNParams::t_full], batch); +} + } /* namespace nntrainer */ diff --git a/nntrainer/layers/bn_layer.h b/nntrainer/layers/bn_layer.h index 67acabd..d0f47b5 100644 --- a/nntrainer/layers/bn_layer.h +++ b/nntrainer/layers/bn_layer.h @@ -110,6 +110,11 @@ public: */ bool supportInPlace() const override { return true; } + /** + * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch) + */ + void setBatch(RunLayerContext &context, unsigned int batch) override; + inline static const std::string type = "batch_normalization"; private: -- 2.7.4