#include <tensorflow/core/framework/graph.pb.h>
-namespace moco
-{
-namespace tf
+namespace
{
-/**
- * @brief GraphBuilder for FusedBatchNorm node
- */
-class FusedBatchNormGraphBuilder final : public GraphBuilder
-{
-public:
- bool validate(const tensorflow::NodeDef &) const override;
- void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
-};
+using namespace moco::tf;
/**
* @brief GraphUpdate for FusedBatchNorm node
std::vector<TensorName> _names;
};
+void FusedBatchNormGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs == 5);
+
+ _node->input(tensor_names->node(_names[0]));
+ _node->gamma(tensor_names->node(_names[1]));
+ _node->beta(tensor_names->node(_names[2]));
+ _node->mean(tensor_names->node(_names[3]));
+ _node->variance(tensor_names->node(_names[4]));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for FusedBatchNorm node
+ */
+class FusedBatchNormGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
bool FusedBatchNormGraphBuilder::validate(const tensorflow::NodeDef &node) const
{
assert(node.input_size() == 5);
updates->enroll(std::move(tf_fbn_update));
}
-void FusedBatchNormGraphUpdate::input(const SymbolTable *tensor_names) const
-{
- int num_inputs = _names.size();
- assert(num_inputs == 5);
-
- _node->input(tensor_names->node(_names[0]));
- _node->gamma(tensor_names->node(_names[1]));
- _node->beta(tensor_names->node(_names[2]));
- _node->mean(tensor_names->node(_names[3]));
- _node->variance(tensor_names->node(_names[4]));
-}
-
} // namespace tf
} // namespace moco