[nnkit] Introduce randomize action (#399)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 28 Jun 2018 01:05:11 +0000 (10:05 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 28 Jun 2018 01:05:11 +0000 (10:05 +0900)
This commit introduces randomize action which randomizes all the tensors
in a tensor context.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/nnkit/actions/builtin/CMakeLists.txt [new file with mode: 0644]
contrib/nnkit/actions/builtin/Randomize.cpp [new file with mode: 0644]

diff --git a/contrib/nnkit/actions/builtin/CMakeLists.txt b/contrib/nnkit/actions/builtin/CMakeLists.txt
new file mode 100644 (file)
index 0000000..51db367
--- /dev/null
@@ -0,0 +1,2 @@
+add_library(nnkit_randomize_action SHARED Randomize.cpp)
+target_link_libraries(nnkit_randomize_action nnkit_intf_action)
diff --git a/contrib/nnkit/actions/builtin/Randomize.cpp b/contrib/nnkit/actions/builtin/Randomize.cpp
new file mode 100644 (file)
index 0000000..2c9e59d
--- /dev/null
@@ -0,0 +1,45 @@
+#include <nnkit/Action.h>
+
+#include <nncc/core/ADT/tensor/IndexRange.h>
+
+#include <chrono>
+#include <random>
+
+using nnkit::TensorContext;
+
+struct RandomizeAction final : public nnkit::Action
+{
+  void run(TensorContext &ctx) override
+  {
+    int seed = std::chrono::system_clock::now().time_since_epoch().count();
+
+    std::minstd_rand rand(seed);
+    std::normal_distribution<float> dist(0.0f, 2.0f);
+
+    for (uint32_t n = 0; n < ctx.size(); ++n)
+    {
+      using nncc::core::ADT::tensor::Accessor;
+
+      auto fn = [&dist, &rand] (const TensorContext &ctx, uint32_t n, Accessor<float> &t)
+      {
+        using nncc::core::ADT::tensor::range;
+        using nncc::core::ADT::tensor::Index;
+
+        range(ctx.shape(n)).iterate() << [&t, &dist, &rand] (const Index &i)
+        {
+          t.at(i) = dist(rand);
+        };
+      };
+
+      ctx.getMutableFloatTensor(n, fn);
+    }
+  }
+};
+
+#include <nnkit/CmdlineArguments.h>
+#include <nncc/foundation/Memory.h>
+
+extern "C" std::unique_ptr<nnkit::Action> make_action(const nnkit::CmdlineArguments &args)
+{
+  return nncc::foundation::make_unique<RandomizeAction>();
+}