is_output(false),
is_virtual(false),
is_trainable(true),
+ is_to_be_removed(false),
need_reorder_weight(false),
node_owned_variable(),
/// @todo distinguish between uninitialized and ADD operator.
return &t;
},
context.getNumWeights());
+
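+  /// a node that owns no weights has nothing to train; mark it non-trainable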
+ if (context.getNumWeights() == 0) {
+ is_trainable = false;
+ }
}
void TfOpNode::setWeightTransformFn(TransformFn fn) { weight_transform = fn; }
void TfOpNode::setInputTransformFn(TransformFn fn) { input_transform = fn; }
+void TfOpNode::setWeights(Variables weights_) {
+  unsigned int cnt = 0;
+  for (auto &w : weights_) {
+    const unsigned int UNIT = w->batch();
+    const unsigned int CHANNEL = w->channel();
+    const unsigned int HEIGHT = w->height();
+    const unsigned int WIDTH = w->width();
+
+    /// copy the incoming tensor into this node's own weight buffer;
+    /// getData() returns a const pointer, so the constness is cast away to
+    /// overwrite in place. assumes weights_ matches this->weights in count
+    /// and shape.
+    auto weight_data = weights.at(cnt)->getData();
+    auto *ptr = const_cast<float *>(weight_data);
+    memcpy(ptr, w->getData(),
+           sizeof(float) * (UNIT * CHANNEL * HEIGHT * WIDTH));
+    cnt++;
+  }
+}
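+/**
+ * Usage sketch (illustrative only; layer_node and its getWeights() are
+ * hypothetical here): after weights change upstream, push the updated
+ * tensors back into the export node with a deep copy.
+ *
+ * @code
+ *   Variables updated = layer_node->getWeights(); // same count/shape assumed
+ *   tf_op_node.setWeights(updated);               // copies the raw floats
+ * @endcode
+ */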
+
void TfOpNode::weightReorder(unsigned int node_count) {
if (need_reorder_weight) {
case tflite::BuiltinOperator_SOFTMAX:
case tflite::BuiltinOperator_TRANSPOSE:
case tflite::BuiltinOperator_MUL:
+
return builtin_ops;
default:
throw std::runtime_error{"Unsupported operator"};
void setNeedReorderWeight() { need_reorder_weight = true; }
/**
+ * @brief Set whether this node is to be removed
+ *
+ * @param to_be_removed true if the node should be removed
+ */
+ void setToBeRemoved(bool to_be_removed) { is_to_be_removed = to_be_removed; }
+
+ /**
* @brief Set the Trainable object
*
*/
void setTrainable(bool trainable) { is_trainable = trainable; }
/**
+ * @brief Set the Inputs object
+ *
+ * @param inputs_
+ */
+ void setInputs(const Variables &inputs_) { inputs = inputs_; }
+
+ /**
+ * @brief Set the Outputs object
+ *
+ * @param outputs_
+ */
+ void setOutputs(const Variables &outputs_) { outputs = outputs_; }
+
+ /**
+ * @brief Set the Weights object
+ *
+ * @param weights_
+ */
+ void setWeights(Variables weights_);
+
+ /**
+ * @brief Replace the Weights object
+ *
+ * @param weights_
+ * @note unlike setWeights(), this does not copy tensor data; it replaces
+ * the weight Variable handles themselves
+ */
+ void replaceWeights(const Variables &weights_) { weights = weights_; }
+ /**
+ * @brief Append a value to the props vector
+ *
+ * @param value
+ */
+ void AppendProps(int value) { props_vector.push_back(value); }
+
+ /**
+ * @brief Append a value to the additional props vector
+ *
+ * @param value
+ */
+ void AppendAdditionalProps(float value) {
+   additional_props.push_back(value);
+ }
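+
+ /**
+ * Usage sketch (illustrative; what each value means depends on the operator
+ * being exported, and these literals are hypothetical):
+ *
+ * @code
+ *   node.AppendProps(1);              // e.g. an integer op option
+ *   node.AppendAdditionalProps(0.5f); // e.g. a float op option
+ * @endcode
+ */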
+
+ /**
* @brief Reorder Weight in case of NCHW --> NHWC
*
* @param node_count
bool isTrainable() const { return is_trainable; }
/**
+ * @brief check if this node is to be removed
+ *
+ * @return true
+ * @return false
+ */
+ bool isToBeRemoved() const { return is_to_be_removed; }
+
+ /**
+ * @brief Get the Props Vector
+ *
+ * @return std::vector<int> a copy of props_vector
+ */
+ std::vector<int> getProps() const { return props_vector; }
+
+ /**
+ * @brief Get the Additional Props Vector
+ *
+ * @return std::vector<float> a copy of additional_props
+ */
+ std::vector<float> getAdditionalProps() const { return additional_props; }
+
+ /**
* @brief Get the Op Type object
*
* @return const tflite::BuiltinOperator
Variables outputs; /**< output variables */
Variables weights; /**< weight variables */
std::vector<TfOpNode *> input_nodes; /**< input nodes */
+ std::vector<int> props_vector; /**< props vector */
+ std::vector<float> additional_props; /**< additional props vector */
TransformFn weight_transform; /**< weight transforms */
/**
 * Q) Why do we need input transform?
 * A) To transform the nntrainer input data format (NCHW) to the tflite
 * format (NHWC)
 **/
TransformFn input_transform; /**< input transforms */
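+ /**
+ * Index-remapping sketch (illustrative only; a real TransformFn here
+ * operates on this node's Variables): for a tensor of shape (N, C, H, W),
+ * an NCHW -> NHWC transform copies each element as
+ *
+ * @code
+ *   dst[((n * H + h) * W + w) * C + c] = src[((n * C + c) * H + h) * W + w];
+ * @endcode
+ */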
bool is_input; /**< true if given input is input; */
bool is_output; /**< true if given output is output; */
bool is_virtual; /**< true if given node is virtual; */
bool is_trainable; /**< true if given node has weight and trainable */
+ bool is_to_be_removed; /**< true if given node is to be removed */
bool need_reorder_weight; /**< true if given node need to reorder weight; */
/// @todo change to shared_ptr or unique_ptr