[Tensor] Support initializer with tensor
authorParichay Kapoor <pk.kapoor@samsung.com>
Tue, 27 Jul 2021 06:09:32 +0000 (15:09 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 10 Aug 2021 22:46:58 +0000 (07:46 +0900)
Support initializer with tensor object natively.
Initialization will be done when the memory is allocated or can be
invoked manually as well.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h

index 9f0070b6796bbf5afb1a54f5823b865ff229c5b9..bafbdc2471cc06624527718e76fc2f87a1f8e9e7 100644 (file)
@@ -239,6 +239,62 @@ void Tensor::setRandUniform(float min, float max) {
     std::uniform_real_distribution<float>(min, max));
 }
 
+void Tensor::initialize() {
+  if (empty() || !isAllocated())
+    return;
+
+  unsigned int fan_in, fan_out;
+
+  /// @fixme: when unit is equal to one, this does not work, we need to rely on
+  /// effective dimension then actual numbers here. For now, some heuristics
+  /// added to infer what would be fan_in/fan_out
+  if (dim.batch() * dim.channel() * dim.height() == 1) {
+    fan_out = fan_in = dim.width();
+  } else if (dim.batch() * dim.channel() == 1) { /// fc layer - 2-D tensor
+    fan_in = dim.height();
+    fan_out = dim.width();
+  } else { /// conv2d filters - 4d tensor, @todo extend this to > 4
+    auto field_size = dim.height() * dim.width();
+
+    // this also handles below cases.
+    // 1. fan_in = fan_out = 1 as well.
+    // 2. batch == 1, channel == 1 and height == 1, theoretical rank of 1
+    fan_in = dim.channel() * field_size;
+    fan_out = dim.batch() * field_size;
+  }
+
+  switch (initializer) {
+  case Tensor::Initializer::ZEROS:
+    setZero();
+    break;
+  case Tensor::Initializer::ONES:
+    setValue(1.0f);
+    break;
+  case Tensor::Initializer::LECUN_NORMAL:
+    setRandNormal(0.0f, sqrtFloat(1.0f / fan_in));
+    break;
+  case Tensor::Initializer::XAVIER_NORMAL:
+    setRandNormal(0.0f, sqrtFloat(2.0f / (fan_in + fan_out)));
+    break;
+  case Tensor::Initializer::HE_NORMAL:
+    setRandNormal(0.0f, sqrtFloat(2.0f / (fan_in)));
+    break;
+  case Tensor::Initializer::LECUN_UNIFORM:
+    setRandUniform(-1.0f * sqrtFloat(1.0f / fan_in), sqrtFloat(1.0f / fan_in));
+    break;
+  case Tensor::Initializer::XAVIER_UNIFORM:
+    setRandUniform(-1.0f * sqrtFloat(6.0f / (fan_in + fan_out)),
+                   sqrtFloat(6.0 / (fan_in + fan_out)));
+    break;
+  case Tensor::Initializer::HE_UNIFORM:
+    setRandUniform(-1.0f * sqrtFloat(6.0f / (fan_in)),
+                   sqrtFloat(6.0 / (fan_in)));
+    break;
+  default:
+    break;
+  }
+}
+
 Tensor::Tensor(
   std::vector<std::vector<std::vector<std::vector<float>>>> const &d) {
 
index aba85af857fc25852949fa6f40a4575dc1a3f763..f1ef5e8d50a1f0814e2d1850fc658b7e03007123 100644 (file)
@@ -207,9 +207,10 @@ public:
 
   friend void swap(Tensor &lhs, Tensor &rhs) noexcept {
     std::swap(lhs.dim, rhs.dim);
-    std::swap(lhs.data, rhs.data);
     std::swap(lhs.strides, rhs.strides);
     std::swap(lhs.is_contiguous, rhs.is_contiguous);
+    std::swap(lhs.initializer, rhs.initializer);
+    std::swap(lhs.data, rhs.data);
   }
 
   /**
@@ -658,11 +659,6 @@ public:
    */
   void standardization_i();
 
-  /**
-   * @brief     Fill the Tensor elements with zero
-   */
-  void setZero();
-
   /**
    * @brief Apply instantly to the element
    *
@@ -767,6 +763,11 @@ public:
    */
   void setValue(float value);
 
+  /**
+   * @brief     Fill the Tensor elements with zero
+   */
+  void setZero();
+
   /**
    * @brief     Set the tensor with random normal distribution
    * @param[in] mean mean of the distribution
@@ -781,6 +782,20 @@ public:
    */
   void setRandUniform(float min = -0.05f, float max = 0.05f);
 
+  /**
+   * @brief     Initialize the memory of the given tensor
+   */
+  void initialize();
+
+  /**
+   * @brief     Initialize the memory of the given tensor
+   * @param     init Initiailizer to use for the initialization
+   */
+  void initialize(Initializer init) {
+    initializer = init;
+    initialize();
+  }
+
   /**
    * @brief     Copy the Tensor
    * @param[in] from Tensor to be copied