[ Mixed Tensor ] Enable FP32 unittest cases
[platform/core/ml/nntrainer.git] / nntrainer / layers / loss / loss_layer.h
index 00b520f..0307f5c 100644 (file)
@@ -52,6 +52,28 @@ public:
    */
   bool requireLabel() const override { return true; }
 
+  /**
+   * @brief set the Tensor Type for the layer
+   * @param     Tensor Type : NCHW or NHWC
+   */
+  void setTensorType(std::array<const std::string, 2> t_type) {
+    if (t_type[0].compare("NCHW") == 0 || t_type[0].compare("nchw") == 0) {
+      tensor_format = ml::train::TensorDim::Format::NCHW;
+    } else {
+      tensor_format = ml::train::TensorDim::Format::NHWC;
+    }
+
+    nntrainer::props::TensorDataType type_;
+
+    from_string(t_type[1], type_);
+
+    tensor_dtype = type_;
+  }
+
+private:
+  ml::train::TensorDim::Format tensor_format;
+  ml::train::TensorDim::DataType tensor_dtype;
+
 protected:
   /**
    * @brief     update loss