for (auto output : node.getOutputs())
{
- auto child =
- std::make_shared<neurun::dumper::dot::DotOperandInfo>(output, operands.at(output));
+ using neurun::dumper::dot::DotOperandInfo;
+ auto child = std::make_shared<DotOperandInfo>(output, operands.at(output),
+ DotOperandInfo::Type::MODEL_OUTPUT);
node_info.appendChild(child);
}
}
if (showing_cond)
{
- neurun::dumper::dot::DotOperandInfo operand_info(index, object);
+ auto type = [&]() {
+ using neurun::dumper::dot::DotOperandInfo;
+ if (_graph.getInputs().contains(index))
+ return DotOperandInfo::Type::MODEL_INPUT;
+ if (_graph.getOutputs().contains(index))
+ return DotOperandInfo::Type::MODEL_OUTPUT;
+ return DotOperandInfo::Type::INTERNAL;
+ }();
+
+ neurun::dumper::dot::DotOperandInfo operand_info(index, object, type);
+
for (auto operation_index : object.getUses().list())
{
auto &node = operations.at(operation_index);
{
const std::string DotOperandInfo::INPUT_SHAPE = "Mdiamond";
+const std::string DotOperandInfo::OUTPUT_SHAPE = "Mdiamond";
const std::string DotOperandInfo::OPERAND_SHAPE = "ellipse";
const std::string DotOperandInfo::BG_COLOR_SCHEME = "set38";
// RED BLUE ORANGE YELLOW GREEN PUPLE CYAN PINK
const std::string DotOperandInfo::BG_COLORS[8] = {"4", "5", "6", "2", "7", "3", "1", "8"};
DotOperandInfo::DotOperandInfo(const neurun::graph::operand::Index &index,
- const neurun::graph::operand::Object &object)
- : _index(index), _object(object)
+ const neurun::graph::operand::Object &object, Type type)
+ : _index(index), _object(object), _type(type)
{
const auto &lower_info = object.lower_info();
if (lower_info)
std::string DotOperandInfo::dot_shape() const
{
- if (_object.isModelInput())
+ switch (_type)
{
- return INPUT_SHAPE;
- }
+ case Type::MODEL_INPUT:
+ return INPUT_SHAPE;
+
+ case Type::MODEL_OUTPUT:
+ return OUTPUT_SHAPE;
- return OPERAND_SHAPE;
+ case Type::UNDEFINED:
+ case Type::INTERNAL:
+ default:
+ return OPERAND_SHAPE;
+ }
}
std::string DotOperandInfo::bg_color_scheme() const { return BG_COLOR_SCHEME; }
class DotOperandInfo : public IDotInfo
{
public:
+ enum class Type
+ {
+ UNDEFINED,
+ MODEL_INPUT,
+ MODEL_OUTPUT,
+ INTERNAL
+ };
+
+public:
static const std::string INPUT_SHAPE;
+ static const std::string OUTPUT_SHAPE;
static const std::string OPERAND_SHAPE;
static const std::string BG_COLOR_SCHEME;
static const std::string BG_COLORS[8];
public:
DotOperandInfo(const neurun::graph::operand::Index &index,
- const neurun::graph::operand::Object &object);
+ const neurun::graph::operand::Object &object, Type type);
public:
virtual std::string index_str() const override;
private:
const neurun::graph::operand::Index &_index;
const neurun::graph::operand::Object &_object;
+ Type _type;
std::vector<std::string> _labels;
};