#include "ResolveFusedBatchNorm.h"
#include "IR/TFAdd.h"
+#include "IR/TFConst.h"
#include "IR/TFMul.h"
#include "Convert.h"
namespace
{
-bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc)
+bool is_same_shape(moco::tf::TFConst *lc, moco::tf::TFConst *rc)
{
if (lc->rank() != rc->rank())
return false;
return true;
}
-void copy_shape(const loco::ConstGen *src, loco::ConstGen *dst)
+void copy_shape(const moco::tf::TFConst *src, moco::tf::TFConst *dst)
{
assert(src != nullptr);
assert(dst != nullptr);
return false;
}
- auto tffbn_gamma = dynamic_cast<loco::ConstGen *>(node->gamma());
- auto tffbn_beta = dynamic_cast<loco::ConstGen *>(node->beta());
- auto tffbn_mean = dynamic_cast<loco::ConstGen *>(node->mean());
- auto tffbn_variance = dynamic_cast<loco::ConstGen *>(node->variance());
+ auto tffbn_gamma = dynamic_cast<moco::tf::TFConst *>(node->gamma());
+ auto tffbn_beta = dynamic_cast<moco::tf::TFConst *>(node->beta());
+ auto tffbn_mean = dynamic_cast<moco::tf::TFConst *>(node->mean());
+ auto tffbn_variance = dynamic_cast<moco::tf::TFConst *>(node->variance());
// all should be const
if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr ||
* %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
* %21:fbn_add_param = ConstGen(fbn_add_param)
*/
- auto const_fbn_mul_0_param = graph->nodes()->create<loco::ConstGen>();
+ auto const_fbn_mul_0_param = graph->nodes()->create<moco::tf::TFConst>();
const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32);
copy_shape(tffbn_gamma, const_fbn_mul_0_param);
const_fbn_mul_0_param->size<loco::DataType::FLOAT32>(const_num_elements);
{
const_fbn_mul_0_param->at<loco::DataType::FLOAT32>(i) = fbn_mul_0_param.get()[i];
}
- auto const_fbn_add_param = graph->nodes()->create<loco::ConstGen>();
+ auto const_fbn_add_param = graph->nodes()->create<moco::tf::TFConst>();
const_fbn_add_param->dtype(loco::DataType::FLOAT32);
copy_shape(tffbn_gamma, const_fbn_add_param);
const_fbn_add_param->size<loco::DataType::FLOAT32>(const_num_elements);