class StaticShapeInferer : public ir::OperationVisitor
{
public:
- StaticShapeInferer(compiler::LoweredGraph *lowered_subg)
+ StaticShapeInferer(compiler::ILoweredGraph *lowered_subg)
: _lowered_subg{lowered_subg}, _subg_input_observers{}, _controlflow_output_observer{nullptr},
_child_inferers{}
{
void dump();
/**
- * @brief Create a lowered model shape inferer map
- * @param[in] lowered_subgs lowered model subgraph map
+ * @brief Create a shape inferer map for a lowered model
+ * @param[in] lowered_subgs lowered model map
* @return Shape inferer map
*/
static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
createStaticShapeInferers(
- const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs);
+ const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs);
private:
- bool checkDynamicInput(const ir::Operation &op);
- bool checkDynamicOutput(const ir::Operation &op);
- void setDynamicOutput(const ir::Operation &op);
+ bool checkDynamicInput(const ir::IOperation &op);
+ bool checkDynamicOutput(const ir::IOperation &op);
+ void setDynamicOutput(const ir::IOperation &op);
private:
// TODO Define visitors for operations. List them in alphabetic order.
void visit(const ir::operation::Gather &op) override;
void visit(const ir::operation::If &op) override;
void visit(const ir::operation::L2Normalization &op) override;
+ void visit(const ir::operation::Loss &op) override;
void visit(const ir::operation::LSTM &op) override;
void visit(const ir::operation::MatrixBandPart &op) override;
void visit(const ir::operation::OneHot &op) override;
void handleSimpleUnaryOp(const ir::Operation &op, const ir::OperandIndex input_idx);
private:
- compiler::LoweredGraph *_lowered_subg;
+ compiler::ILoweredGraph *_lowered_subg;
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<OperandObserver>>
_subg_input_observers; // child subg input
std::unique_ptr<OperandObserver> _controlflow_output_observer; // parent controlflow op output