*/
bool requireLabel() const override { return true; }
+ /**
+ * @brief set the Tensor Type for the layer
+ * @param Tensor Type : NCHW or NHWC
+ */
+ void setTensorType(std::array<const std::string, 2> t_type) {
+ if (t_type[0].compare("NCHW") == 0 || t_type[0].compare("nchw") == 0) {
+ tensor_format = ml::train::TensorDim::Format::NCHW;
+ } else {
+ tensor_format = ml::train::TensorDim::Format::NHWC;
+ }
+
+ nntrainer::props::TensorDataType type_;
+
+ from_string(t_type[1], type_);
+
+ tensor_dtype = type_;
+ }
+
+private:
+ ml::train::TensorDim::Format tensor_format;
+ ml::train::TensorDim::DataType tensor_dtype;
+
protected:
/**
* @brief update loss