TFLType *_fusable_node;
locoex::TFLConst *_const_node;
locoex::TFLConv2D *_conv2d_node;
+
+ locoex::TFLConst *create_fused_bias();
};
+template <typename TFLType> locoex::TFLConst *Fuser<TFLType>::create_fused_bias()
+{
+ // we have to create a new bias by adding/substracting bias and const node (of TFLAdd or TFLSub)
+ auto bias = dynamic_cast<locoex::TFLConst *>(_conv2d_node->bias());
+ assert(bias->dtype() == loco::DataType::FLOAT32 &&
+ _const_node->dtype() == loco::DataType::FLOAT32);
+
+ assert(bias->rank() == 1 && _const_node->rank() == 1);
+ assert(bias->dim(0) == _const_node->dim(0));
+
+ // build a new bias
+ auto new_bias = _graph->nodes()->create<locoex::TFLConst>();
+ {
+ new_bias->dtype(loco::DataType::FLOAT32);
+
+ new_bias->rank(1);
+ new_bias->dim(0) = bias->dim(0);
+
+ new_bias->size<loco::DataType::FLOAT32>(bias->dim(0).value());
+
+ for (uint32_t x = 0; x < bias->dim(0).value(); x++)
+ new_bias->at<loco::DataType::FLOAT32>(x) = calc<TFLType>(
+ bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x));
+ }
+
+ return new_bias;
+}
+
/**
* @brief fuse TFLAdd or TFLSub into conv2d
*/
template <typename TFLType> void Fuser<TFLType>::fuse(void)
{
+ // TODO replace code below with `auto new_bias = create_fused_bias();`
+
// conv2d bias could be either all 0 values or set with other values.
// we have to create a new bias by adding/substracting const_node and bias
auto bias = dynamic_cast<locoex::TFLConst *>(_conv2d_node->bias());
new_bias->at<loco::DataType::FLOAT32>(x) = calc<TFLType>(
bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x));
}
+ // TODO replace the above
// replace node with new_bias
// note that loco::replace() is not used because bias could be input of other op just in case