{
/**
- * @brief GraphBuilder for TF FusedBatchNorm node
+ * @brief GraphBuilder for FusedBatchNorm node
*/
-class TFFusedBatchNormGraphBuilder final : public GraphBuilder
+class FusedBatchNormGraphBuilder final : public GraphBuilder
{
public:
bool validate(const tensorflow::NodeDef &) const override;
};
/**
- * @brief GraphUpdate for TF FusedBatchNorm node
+ * @brief GraphUpdate for FusedBatchNorm node
*/
-class TFFusedBatchNormGraphUpdate final : public GraphUpdate
+class FusedBatchNormGraphUpdate final : public GraphUpdate
{
public:
- TFFusedBatchNormGraphUpdate(TFFusedBatchNorm *node, std::vector<TensorName> names)
+ FusedBatchNormGraphUpdate(TFFusedBatchNorm *node, std::vector<TensorName> names)
: _node(node), _names(names)
{
}
std::vector<TensorName> _names;
};
-bool TFFusedBatchNormGraphBuilder::validate(const tensorflow::NodeDef &node) const
+bool FusedBatchNormGraphBuilder::validate(const tensorflow::NodeDef &node) const
{
assert(node.input_size() == 5);
return has_attrs(node, {"epsilon"});
}
-void TFFusedBatchNormGraphBuilder::build(const tensorflow::NodeDef &node,
- GraphBuilderContext *context) const
+void FusedBatchNormGraphBuilder::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
{
assert(context != nullptr);
fbn_input_names.push_back(TensorName(node.input(3))); // mean
fbn_input_names.push_back(TensorName(node.input(4))); // variance
- auto tf_fbn_update = stdex::make_unique<TFFusedBatchNormGraphUpdate>(tf_fbn, fbn_input_names);
+ auto tf_fbn_update = stdex::make_unique<FusedBatchNormGraphUpdate>(tf_fbn, fbn_input_names);
updates->enroll(std::move(tf_fbn_update));
}
-void TFFusedBatchNormGraphUpdate::input(const SymbolTable *tensor_names) const
+void FusedBatchNormGraphUpdate::input(const SymbolTable *tensor_names) const
{
int num_inputs = _names.size();
assert(num_inputs == 5);
#include "GraphBuilderRegistry.h"
-REGISTER_OP_BUILDER(FusedBatchNorm, TFFusedBatchNormGraphBuilder)
+REGISTER_OP_BUILDER(FusedBatchNorm, FusedBatchNormGraphBuilder)