return false;
moco::tf::TFConst *mulparam = nullptr;
- moco::tf::TFConv2D *conv2d = nullptr;
+ moco::tf::TFNode *precedingOp = nullptr;
// TODO support DepthWiseConv2D
// TODO support FullyConnected
if (xc != nullptr)
{
mulparam = xc;
- conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->y());
+ precedingOp = dynamic_cast<moco::tf::TFNode *>(node->y());
}
else // yc != nullptr
{
mulparam = yc;
- conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->x());
+ precedingOp = dynamic_cast<moco::tf::TFNode *>(node->x());
}
- if (conv2d == nullptr)
- return false;
-
assert(mulparam->dtype() == loco::DataType::FLOAT32);
- auto conv2d_fused =
- fused_conv_node<FuseType::Conv2D, moco::tf::TFConv2D>(graph, mulparam, conv2d);
+ moco::tf::TFNode *fused_node = nullptr;
+ if (auto conv2d = dynamic_cast<moco::tf::TFConv2D *>(precedingOp))
+ fused_node = fused_conv_node<FuseType::Conv2D, moco::tf::TFConv2D>(graph, mulparam, conv2d);
// Not ready yet
- if (conv2d_fused == nullptr)
+ if (fused_node == nullptr)
return false;
// Replace TFMul node with new precedingOp with fused kernel
// This will leave existing precedingOp as-is but can be removed if not used
// from other transformations
- replace(node).with(conv2d_fused);
+ replace(node).with(fused_node);
// TODO check if need to disconnect
// node->x(nullptr);
// node->y(nullptr);
return false;
moco::tf::TFConst *addparam = nullptr;
- moco::tf::TFConv2D *conv2d = nullptr;
+ moco::tf::TFNode *precedingOp = nullptr;
moco::tf::TFBiasAdd *biasadd = nullptr;
// TODO support DepthWiseConv2D
// TODO support FullyConnected
if (xc != nullptr)
{
addparam = xc;
- conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->y());
- biasadd = dynamic_cast<moco::tf::TFBiasAdd *>(node->y());
+ precedingOp = dynamic_cast<moco::tf::TFNode *>(node->y());
}
else // yc != nullptr
{
addparam = yc;
- conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->x());
- biasadd = dynamic_cast<moco::tf::TFBiasAdd *>(node->x());
+ precedingOp = dynamic_cast<moco::tf::TFNode *>(node->x());
}
auto addparam_shape_inf = addparam->annot<moco::tf::ShapeInferenceData>();
return false;
}
- if (conv2d != nullptr)
+ if (auto conv2d = dynamic_cast<moco::tf::TFConv2D *>(precedingOp))
{
biasadd = create_biasadd_node<moco::tf::TFConv2D>(graph, addparam, conv2d);
}
+ else if (auto old_bias_add = dynamic_cast<moco::tf::TFBiasAdd *>(precedingOp))
+ {
+ biasadd = old_bias_add;
+ }
if (biasadd == nullptr)
{