}
/**
+ * @brief Create a fused convolution operation from kernel of fused mulparam
+ *
+ * @tparam FT        FuseType selecting how the fused kernel is built
+ * @tparam T         Type of the convolution node to create (e.g. TFConv2D)
+ * @param graph      Graph in which the fused node is created
+ * @param mulparam   Constant multiplier to be fused into the kernel
+ * @param conv_node  Convolution node whose kernel is fused with mulparam
+ * @return Fused convolution operation, or nullptr when fusion cannot be done (yet)
+ */
+template <FuseType FT, class T>
+T *fused_conv_node(loco::Graph *graph, moco::tf::TFConst *mulparam, T *conv_node)
+{
+  LOGGER(l);
+
+  // ker should be constant
+  auto ker = dynamic_cast<moco::tf::TFConst *>(conv_node->ker());
+  if (ker == nullptr)
+  {
+    // Wait until ker becomes TFConst: there are cases when it's Identity.
+    INFO(l) << "Mul fuse_to_preceding: precedingOp ker is not TFConst";
+    return nullptr;
+  }
+  auto ifm = conv_node->ifm();
+  assert(ifm != nullptr);
+
+  // we need shape information, if not wait till it's ready
+  if (ker->annot<moco::tf::ShapeInferenceData>() == nullptr)
+  {
+    INFO(l) << "Mul fuse_to_preceding: precedingOp ker has no shape";
+    return nullptr;
+  }
+
+  auto mulparam_shape_inf = mulparam->annot<moco::tf::ShapeInferenceData>();
+  if (mulparam_shape_inf == nullptr)
+  {
+    INFO(l) << "Mul fuse_to_preceding: mulparam has no shape";
+    return nullptr;
+  }
+  // if MulParam rank is not 1 we cannot fuse, just skip
+  auto mulparam_shape = mulparam_shape_inf->tensor_shape();
+  if (mulparam_shape.rank() != 1)
+  {
+    INFO(l) << "Mul fuse_to_preceding: Mul rank is not 1";
+    return nullptr;
+  }
+
+  // Build the fused kernel, then a new convolution node that uses it;
+  // padding/layout/strides are copied from the original convolution node.
+  auto ker_fused = create_kernel_from_fuse_mulparam<FT>(graph, ker, mulparam);
+  auto conv_fused = graph->nodes()->create<T>();
+
+  conv_fused->ifm(ifm);
+  conv_fused->ker(ker_fused);
+  conv_fused->padding(conv_node->padding());
+  conv_fused->data_layout(conv_node->data_layout());
+  conv_fused->strides(conv_node->strides());
+
+  return conv_fused;
+}
+
+/**
* @note This creates fused ker:2 from ker:1, 'mulparam' and
* new precedingOp:2 that uses ker:2 as the kernel.
* Then make C to use precedingOp:2 as new input.
assert(mulparam->dtype() == loco::DataType::FLOAT32);
- // ker should be constant
- auto ker = dynamic_cast<moco::tf::TFConst *>(conv2d->ker());
- if (ker == nullptr)
- {
- // Wait until ker is becomes TFConst: there are cases when it's Identity.
- INFO(l) << "Mul fuse_to_preceding: precedingOp ker is not TFConst";
- return false;
- }
- auto ifm = conv2d->ifm();
- assert(ifm != nullptr);
+ auto conv2d_fused =
+ fused_conv_node<FuseType::Conv2D, moco::tf::TFConv2D>(graph, mulparam, conv2d);
- // we need shape information, if not wait till it's ready
- if (ker->annot<moco::tf::ShapeInferenceData>() == nullptr)
- {
- INFO(l) << "Mul fuse_to_preceding: precedingOp ker has no shape";
+ // Not ready yet
+ if (conv2d_fused == nullptr)
return false;
- }
-
- auto mulparam_shape_inf = mulparam->annot<moco::tf::ShapeInferenceData>();
- if (mulparam_shape_inf == nullptr)
- {
- INFO(l) << "Mul fuse_to_preceding: mulparam has no shape";
- return false;
- }
- // if MulParam rank is not 1 we cannot fuse, just skip
- auto mulparam_shape = mulparam_shape_inf->tensor_shape();
- if (mulparam_shape.rank() != 1)
- {
- INFO(l) << "Mul fuse_to_preceding: Mul rank is not 1";
- return false;
- }
-
- auto ker_fused = create_kernel_from_fuse_mulparam<FuseType::Conv2D>(graph, ker, mulparam);
- auto conv2d_fused = graph->nodes()->create<moco::tf::TFConv2D>();
-
- conv2d_fused->ifm(ifm);
- conv2d_fused->ker(ker_fused);
- conv2d_fused->padding(conv2d->padding());
- conv2d_fused->data_layout(conv2d->data_layout());
- conv2d_fused->strides(conv2d->strides());
// Replace TFMul node with new precedingOp with fused kernel
// This will leave existing precedingOp as-is but can be removed if not used