Supporting BiasAdd is needed for more ops (other than Conv2D), so the pass is renamed to a more general term.
Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
KNOB_BOOL(ConvertTensorTranspose, true, Resolve loco::TensorTranspose)
// Optimization pass
-KNOB_BOOL(UseFuseConv2DAddSubPass, true, Fuse TFLAdd or TFLSub into TFLConv2D)
+KNOB_BOOL(UseFuseBiasAddPass, true, Fuse TFLAdd or TFLSub into TFLConv2D)
KNOB_BOOL(UseFuseReluPass, true, Fuse TFLAdd or TFLSub into TFLConv2D or so)
* limitations under the License.
*/
-#include "FuseConv2DAddSubPass.h"
+#include "FuseBiasAddPass.h"
#include "Dialect/IR/TFLNodes.h"
#include "Dialect/IR/TFLDialect.h"
namespace exo
{
-bool FuseConv2DAddSubPass::run(loco::Graph *g)
+bool FuseBiasAddPass::run(loco::Graph *g)
{
Collector collector;
* limitations under the License.
*/
-#ifndef __PASS_FUSE_CONV2D_ADDSUB_PASS_H__
-#define __PASS_FUSE_CONV2D_ADDSUB_PASS_H__
+#ifndef __PASS_FUSE_BIASADD_PASS_H__
+#define __PASS_FUSE_BIASADD_PASS_H__
#include <logo/Pass.h>
{
/**
- * @brief Class to fuse TFLAdd or TFLSub that follows TFLConv2D
+ * @brief Class to fuse TFLAdd or TFLSub into Bias input of the following ops:
+ * - TFLConv2D, TFLDepthwiseConv2D
+ * - TODO Consider adding FullyConnected and LSTM-related ops (see Toco's implementation
+ *   in ConvertLstmCellOperator)
*
* Case 1. Conv2D and TFLAdd
*
* @note TFLSub, of which x() == TFLConv2D and y() == TFLConst, will be fused.
* If x() == TFLConst and y() == TFLConv2D, it won't be fused.
*/
-struct FuseConv2DAddSubPass final : public logo::Pass
+struct FuseBiasAddPass final : public logo::Pass
{
- const char *name(void) const final { return "exo::FuseConv2DAddSubPass"; }
+ const char *name(void) const final { return "exo::FuseBiasAddPass"; }
bool run(loco::Graph *g) final;
};
} // namespace exo
-#endif // __PASS_FUSE_CONV2D_ADDSUB_PASS_H__
+#endif // __PASS_FUSE_BIASADD_PASS_H__
* limitations under the License.
*/
-#include "FuseConv2DAddSubPass.h"
+#include "FuseBiasAddPass.h"
#include "Dialect/IR/TFLNodes.h"
#include "TestGraph.h"
{
exo::test::TypeShapeReadyPhase test_phase;
- test_phase.add_pass<exo::FuseConv2DAddSubPass>();
+ test_phase.add_pass<exo::FuseBiasAddPass>();
test_phase.run(g.graph());
}
{
exo::test::TypeShapeReadyPhase test_phase;
- test_phase.add_pass<exo::FuseConv2DAddSubPass>();
+ test_phase.add_pass<exo::FuseBiasAddPass>();
test_phase.run(g.graph());
}
{
exo::test::TypeShapeReadyPhase test_phase;
- test_phase.add_pass<exo::FuseConv2DAddSubPass>();
+ test_phase.add_pass<exo::FuseBiasAddPass>();
test_phase.run(g.graph());
}
{
exo::test::TypeShapeReadyPhase test_phase;
- test_phase.add_pass<exo::FuseConv2DAddSubPass>();
+ test_phase.add_pass<exo::FuseBiasAddPass>();
test_phase.run(g.graph());
}
{
exo::test::TypeShapeReadyPhase test_phase;
- test_phase.add_pass<exo::FuseConv2DAddSubPass>();
+ test_phase.add_pass<exo::FuseBiasAddPass>();
test_phase.run(g.graph());
}
{
exo::test::TypeShapeReadyPhase test_phase;
- test_phase.add_pass<exo::FuseConv2DAddSubPass>();
+ test_phase.add_pass<exo::FuseBiasAddPass>();
test_phase.run(g.graph());
}
// Please add in alphabetical order
#include "Pass/FoldTransposeOfConst.h"
-#include "Pass/FuseConv2DAddSubPass.h"
+#include "Pass/FuseBiasAddPass.h"
#include "Pass/FuseReluPass.h"
#include "Pass/MergeConcatNodesPass.h"
#include "Pass/ShapeInferencePass.h"
phase.emplace_back(stdex::make_unique<FoldTransposeOfConst>());
- if (get<Knob::UseFuseConv2DAddSubPass>())
+ if (get<Knob::UseFuseBiasAddPass>())
{
- phase.emplace_back(stdex::make_unique<FuseConv2DAddSubPass>());
+ phase.emplace_back(stdex::make_unique<FuseBiasAddPass>());
}
if (get<Knob::UseFuseReluPass>())