#include <loco.h>
#include <memory>
+#include <vector>
namespace locomotiv
{
{
public:
Session() = delete;
- Session(loco::Graph *g) : _graph(g)
+
+ /// @brief Make Session whose outputs are the graph's own outputs, i.e. loco::output_nodes(g)
+ Session(loco::Graph *g) : Session(g, loco::output_nodes(g))
+ {
+ // DO NOTHING: the delegated-to constructor performs all initialization
+ }
+
+ /**
+ * @brief Make Session for graph with selective custom outputs. Only the
+ * subgraph needed to calculate the given outputs is executed.
+ *
+ * @param g Graph to run; only the pointer is stored in this Session
+ * @param custom_outputs Nodes to treat as outputs; the vector is copied
+ *
+ * @note Set required inputs for given outputs, or inference may fail.
+ * @note custom_outputs need not be graph outputs; they can be any nodes
+ * in the middle of the graph.
+ * @warning This approach may fail in case of graph with control flow
+ */
+ Session(loco::Graph *g, const std::vector<loco::Node *> &custom_outputs)
+ : _graph(g), _outputs(custom_outputs)
{
// DO NOTHING: both members are set in the initializer list
}
/// @brief Run inference for this Session
/// NOTE(review): the definition is not visible here; per the custom-output
/// constructor's documentation only the subgraph needed for the held outputs
/// is expected to run — confirm against the implementation.
void infer();
/// @brief Get number of outputs held by this Session (the output nodes set at construction)
- uint32_t output_size() const { return _graph->outputs()->size(); }
+ uint32_t output_size() const { return _outputs.size(); }
/**
 * @brief Get output of graph as NodeData
 *
 * @param index Position of the requested output
 *
 * NOTE(review): presumably index addresses the same output list that
 * output_size() counts, and the result is only meaningful after infer() —
 * confirm against the definition, which is not visible here.
 */
const NodeData *get_output(uint32_t index);
+ /// @brief Get the node held as the index-th output of this Session
+ /// @param index Position within the held outputs; must be less than output_size()
+ /// @throws std::out_of_range if index is out of bounds (from std::vector::at)
+ const loco::Node *get_output_node(uint32_t index) const { return _outputs.at(index); }
+
private:
/// Graph passed to the constructor; kept as a raw pointer
loco::Graph *_graph;
/// Output nodes of this Session, copied from the constructor's custom_outputs
+ std::vector<loco::Node *> _outputs;
};
} // namespace locomotiv
}
}
+TEST(Session, session_for_subgraph)
+{
+ /*
+ * Check that a Session constructed with custom outputs reports exactly
+ * those nodes as its outputs and, after infer(), yields the expected data
+ * for each of them.
+ *
+ * Make following graph:
+ * ConstGen_1 --
+ * \
+ * ConstGen_2 --- TensorConcat_1 --- TensorConcat_3 --- Push
+ * /
+ * ConstGen_3 --- TensorConcat_2 --
+ * /
+ * ConstGen_4 --
+ */
+ auto g = loco::make_graph();
+
+ auto c1 = g->nodes()->create<loco::ConstGen>();
+ auto c2 = g->nodes()->create<loco::ConstGen>();
+ auto c3 = g->nodes()->create<loco::ConstGen>();
+ auto c4 = g->nodes()->create<loco::ConstGen>();
+
+ c1->dtype(loco::DataType::FLOAT32);
+ c2->dtype(loco::DataType::FLOAT32);
+ c3->dtype(loco::DataType::FLOAT32);
+ c4->dtype(loco::DataType::FLOAT32);
+ c1->shape({1});
+ c2->shape({1});
+ c3->shape({1});
+ c4->shape({1});
+ c1->size<loco::DataType::FLOAT32>(1);
+ c2->size<loco::DataType::FLOAT32>(1);
+ c3->size<loco::DataType::FLOAT32>(1);
+ c4->size<loco::DataType::FLOAT32>(1);
+
+ c1->at<loco::DataType::FLOAT32>(0) = 0.1f;
+ c2->at<loco::DataType::FLOAT32>(0) = 0.2f;
+ c3->at<loco::DataType::FLOAT32>(0) = 0.3f;
+ c4->at<loco::DataType::FLOAT32>(0) = 0.4f;
+
+ auto t1 = g->nodes()->create<loco::TensorConcat>();
+ auto t2 = g->nodes()->create<loco::TensorConcat>();
+ auto t3 = g->nodes()->create<loco::TensorConcat>();
+
+ // Note: default concat axis is 0
+ t1->lhs(c1);
+ t1->rhs(c2);
+ t2->lhs(c3);
+ t2->rhs(c4);
+ t3->lhs(t1);
+ t3->rhs(t2);
+
+ auto push = g->nodes()->create<loco::Push>();
+ push->from(t3);
+
+ {
+ // Session to get t1 only: t1 concatenates c1 and c2, so expect shape {2} with 0.1, 0.2
+ locomotiv::Session s(g.get(), {t1});
+ ASSERT_EQ(s.output_size(), 1);
+ ASSERT_EQ(s.get_output_node(0), dynamic_cast<loco::Node *>(t1));
+
+ s.infer();
+
+ auto t1_data = s.get_output(0);
+ ASSERT_NE(t1_data, nullptr);
+ ASSERT_EQ(*(t1_data->shape()), Shape{2});
+
+ auto t1_buf = t1_data->as_f32_bufptr();
+ ASSERT_EQ(t1_buf->at({0}), 0.1f);
+ ASSERT_EQ(t1_buf->at({1}), 0.2f);
+ }
+
+ {
+ // Session to get t2 only: t2 concatenates c3 and c4, so expect shape {2} with 0.3, 0.4
+ locomotiv::Session s(g.get(), {t2});
+ ASSERT_EQ(s.output_size(), 1);
+ ASSERT_EQ(s.get_output_node(0), dynamic_cast<loco::Node *>(t2));
+
+ s.infer();
+
+ auto t2_data = s.get_output(0);
+ ASSERT_NE(t2_data, nullptr);
+ ASSERT_EQ(*(t2_data->shape()), Shape{2});
+
+ auto t2_buf = t2_data->as_f32_bufptr();
+ ASSERT_EQ(t2_buf->at({0}), 0.3f);
+ ASSERT_EQ(t2_buf->at({1}), 0.4f);
+ }
+
+ {
+ // Session to get t2 and push: push reads t3 (t1 then t2), so expect shape {4} with 0.1..0.4
+ locomotiv::Session s(g.get(), {t2, push});
+ ASSERT_EQ(s.output_size(), 2);
+ ASSERT_EQ(s.get_output_node(0), dynamic_cast<loco::Node *>(t2));
+ ASSERT_EQ(s.get_output_node(1), dynamic_cast<loco::Node *>(push));
+
+ s.infer();
+
+ auto t2_data = s.get_output(0);
+ ASSERT_NE(t2_data, nullptr);
+ ASSERT_EQ(*(t2_data->shape()), Shape{2});
+
+ auto t2_buf = t2_data->as_f32_bufptr();
+ ASSERT_EQ(t2_buf->at({0}), 0.3f);
+ ASSERT_EQ(t2_buf->at({1}), 0.4f);
+
+ auto push_data = s.get_output(1);
+ ASSERT_NE(push_data, nullptr);
+ ASSERT_EQ(*(push_data->shape()), Shape{4});
+
+ auto push_buf = push_data->as_f32_bufptr();
+ ASSERT_EQ(push_buf->at({0}), 0.1f);
+ ASSERT_EQ(push_buf->at({1}), 0.2f);
+ ASSERT_EQ(push_buf->at({2}), 0.3f);
+ ASSERT_EQ(push_buf->at({3}), 0.4f);
+ }
+}
+
// Below here is internal test for locomotiv, i.e. not public usage of locomotiv
#include "NodeDataImpl.h"
#include "NodeDomain.h"