From: Сергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 Date: Fri, 6 Sep 2019 21:11:14 +0000 (+0300) Subject: [mir] Refactor graph dot dumper (#7273) X-Git-Tag: accepted/tizen/unified/20190911.111615~51 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fc589ab0ef880554d3e5a4b3aa133e14a3b43e49;p=platform%2Fcore%2Fml%2Fnnfw.git [mir] Refactor graph dot dumper (#7273) Refactor graph dot dumper to loosen dependencies. Signed-off-by: Sergei Barannikov --- diff --git a/compiler/mir/CMakeLists.txt b/compiler/mir/CMakeLists.txt index 56a8037..9f6db34 100644 --- a/compiler/mir/CMakeLists.txt +++ b/compiler/mir/CMakeLists.txt @@ -1,29 +1,29 @@ set(MIR_SOURCES - src/ops/AvgPool2DOp.cpp - src/ops/BinaryElementwiseOp.cpp - src/ops/ConcatOp.cpp - src/ops/Conv2DOp.cpp - src/ops/DeConv2DOp.cpp - src/ops/DepthwiseConv2DOp.cpp - src/ops/FullyConnectedOp.cpp - src/ops/GatherOp.cpp - src/ops/MaxPool2DOp.cpp - src/ops/PadOp.cpp - src/ops/ReduceOp.cpp - src/ops/SqueezeOp.cpp - src/ops/SliceOp.cpp - src/ops/TransposeOp.cpp - src/Graph.cpp - src/Index.cpp - src/ir_dot_builder.cpp - src/IrDotDumper.cpp - src/GraphPatternMatcher.cpp - src/ir_dot_node_info.cpp - src/Operation.cpp - src/Shape.cpp - src/Tensor.cpp - src/TensorVariant.cpp - src/Visitor.cpp) + src/ops/AvgPool2DOp.cpp + src/ops/BinaryElementwiseOp.cpp + src/ops/ConcatOp.cpp + src/ops/Conv2DOp.cpp + src/ops/DeConv2DOp.cpp + src/ops/DepthwiseConv2DOp.cpp + src/ops/FullyConnectedOp.cpp + src/ops/GatherOp.cpp + src/ops/MaxPool2DOp.cpp + src/ops/PadOp.cpp + src/ops/ReduceOp.cpp + src/ops/SqueezeOp.cpp + src/ops/SliceOp.cpp + src/ops/TransposeOp.cpp + src/DotGraph.cpp + src/DotNodeBuilder.cpp + src/Graph.cpp + src/GraphPatternMatcher.cpp + src/Index.cpp + src/IrDotDumper.cpp + src/Operation.cpp + src/Shape.cpp + src/Tensor.cpp + src/TensorVariant.cpp + src/Visitor.cpp) add_library(mir STATIC ${MIR_SOURCES}) target_include_directories(mir PUBLIC include) diff --git a/compiler/mir/include/mir/DataFormat.h b/compiler/mir/include/mir/DataFormat.h index 5979d08..44edfa8 100644 --- a/compiler/mir/include/mir/DataFormat.h +++ b/compiler/mir/include/mir/DataFormat.h @@ -17,6 +17,9 @@ #ifndef _MIR_DATA_FORMAT_H_ #define _MIR_DATA_FORMAT_H_ +#include +#include + namespace mir { @@ -68,6 +71,20 @@ inline int getDataSpatialDimIndex(DataFormat data_format, int dim) } } +inline std::string toString(DataFormat data_format) +{ + switch (data_format) + { + case DataFormat::NCHW: + return "NCHW"; + case DataFormat::NHWC: + return "NHWC"; + default: + assert(false); + return ""; // Dummy value to silence compiler warning. + } +} + } // namespace mir #endif //_MIR_DATA_FORMAT_H_ diff --git a/compiler/mir/include/mir/IrDotDumper.h b/compiler/mir/include/mir/IrDotDumper.h index ae86c16..e6c295c 100644 --- a/compiler/mir/include/mir/IrDotDumper.h +++ b/compiler/mir/include/mir/IrDotDumper.h @@ -17,58 +17,12 @@ #ifndef _MIR_IR_DOT_DUMPER_H_ #define _MIR_IR_DOT_DUMPER_H_ -#include "mir/Visitor.h" -#include "mir/ir_dot_builder.h" +#include namespace mir { -/** - * @brief Model IR visitor that can be used to output Model IR as a .dot graph. - * @usage Run on a Model IR graph as a visitor, and then call writeDot passing it a stream - */ -class IrDotDumper : public IVisitor -{ -public: - void visit(ops::AddOp &op) override; - void visit(ops::AvgPool2DOp &op) override; - void visit(ops::CappedReluOp &op) override; - void visit(ops::ConcatOp &op) override; - void visit(ops::ConstantOp &op) override; - void visit(ops::Conv2DOp &op) override; - void visit(ops::DeConv2DOp &op) override; - void visit(ops::DepthwiseConv2DOp &op) override; - void visit(ops::DivOp &op) override; - void visit(ops::EluOp &op) override; - void visit(ops::FullyConnectedOp &op) override; - void visit(ops::GatherOp &op) override; - void visit(ops::InputOp &op) override; - void visit(ops::LeakyReluOp &op) override; - void visit(ops::MaxOp &op) override; - void visit(ops::MaxPool2DOp &op) override; - void visit(ops::MulOp &op) override; - void visit(ops::OutputOp &op) override; - void visit(ops::PadOp &op) override; - void visit(ops::ReduceMeanOp &op) override; - void visit(ops::ReluOp &op) override; - void visit(ops::ReshapeOp &op) override; - void visit(ops::ResizeOp &op) override; - void visit(ops::SigmoidOp &op) override; - void visit(ops::SliceOp &op) override; - void visit(ops::SoftmaxOp &op) override; - void visit(ops::SqrtOp &op) override; - void visit(ops::SqueezeOp &op) override; - void visit(ops::SubOp &op) override; - void visit(ops::TanhOp &op) override; - void visit(ops::TransposeOp &op) override; - - void writeDot(std::ostream &os) { _dot_builder.writeDot(os); }; - -private: - void addSimpleOperation(const Operation &op); - - IrDotBuilder _dot_builder; -}; +class Graph; void dumpGraph(const Graph *graph, std::ostream &stream); diff --git a/compiler/mir/include/mir/Shape.h b/compiler/mir/include/mir/Shape.h index f3082e8..fbd4951 100644 --- a/compiler/mir/include/mir/Shape.h +++ b/compiler/mir/include/mir/Shape.h @@ -20,8 +20,6 @@ #include #include #include -#include -#include #include "adtidas/SmallVector.h" #include "mir/Common.h" @@ -68,7 +66,7 @@ private: adt::small_vector _dims; }; -std::ostream &operator<<(std::ostream &s, const Shape &sh); +std::string toString(const Shape &shape); } // namespace mir diff --git a/compiler/mir/include/mir/ir_dot_node_info.h b/compiler/mir/include/mir/ir_dot_node_info.h deleted file mode 100644 index 8852dd2..0000000 --- a/compiler/mir/include/mir/ir_dot_node_info.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _MIR_IR_DOT_NODE_INFO_H_ -#define _MIR_IR_DOT_NODE_INFO_H_ - -#include "mir/Operation.h" -#include "mir/Shape.h" -#include "mir/ops/CommonProps.h" - -namespace mir -{ - -/** - * @brief Can collect information about a NN operator, and then use it to output - * this info as a node label in .dot format. - * @usage Provides a typical builder interface for collecting NN operator info, for example: - * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, - * 4}}).withMisc("Axis", 0)); - * Then resulting .dot node label is accessed with info.getLabel(); - */ -class DotIrNodeInfo -{ -public: - using Shapes = std::vector; - using NamedShape = std::pair; - using MiscVal = std::pair; - using PadType = ops::PaddingType; - - class Stringable - { - public: - template - /*implicit*/ Stringable( - T val); // NOLINT(google-explicit-constructor, hicpp-explicit-conversions) - - /*implicit*/ operator std::string &&() - { // NOLINT(google-explicit-constructor, hicpp-explicit-conversions) - return std::move(_val); - } - - private: - std::string _val; - }; - - DotIrNodeInfo() = default; - - DotIrNodeInfo &withType(const std::string &type_name); - DotIrNodeInfo &withInShapes(Shapes &&in_shapes); - DotIrNodeInfo &withOutShapes(Shapes &&out_shapes); - - DotIrNodeInfo &withStride(const Shape &stride_shape); - DotIrNodeInfo &withShape(const std::string &shape_name, const Shape &shape); - DotIrNodeInfo &withPadType(PadType pad_type); - DotIrNodeInfo &withMisc(const std::string &name, Stringable val); - - /** - * Create a label in dot format for the Model IR node. - * Label is created in the form of the table with node name and type on top, - * then in and out shapes in the left column, and other parameters on the right column. - * - * Dot syntax for tables: - * - pipe ("|") symbol adds another line/column; by default it adds columns - * - when something gets inside of "{}", it changes what the pipe adds - * Example: label="leftCol | middleCol | {firstRowInRightCol | secondRowInRightCol }" - */ - std::string getLabel() const; - -private: - void writeInShapesLabel(std::stringstream &ss) const; - void writeOutShapesLabel(std::stringstream &ss) const; - - std::string labelForPadAndPool() const; - std::string labelForNodeParams() const; - void addPipeIfNeeded(std::stringstream &ss, bool needed, bool &need_pipe) const; - - std::string _type_name; - - Shapes _in_shapes; - Shapes _out_shapes; - - Shape _kernel_shape; - - Shape _stride_shape; - std::vector _shapes; - std::vector _misc_vals; - - bool _has_pad = false; - PadType _pad_type = PadType::Valid; - - bool _has_pool = false; -}; - -template DotIrNodeInfo::Stringable::Stringable(T val) : _val(std::to_string(val)) {} - -template <> DotIrNodeInfo::Stringable::Stringable(std::string val); - -template <> DotIrNodeInfo::Stringable::Stringable(const char *val); - -} // namespace mir - -#endif //_MIR_IR_DOT_NODE_INFO_H_ diff --git a/compiler/mir/src/ir_dot_builder.cpp b/compiler/mir/src/DotGraph.cpp similarity index 50% rename from compiler/mir/src/ir_dot_builder.cpp rename to compiler/mir/src/DotGraph.cpp index e085275..b0e92e1 100644 --- a/compiler/mir/src/ir_dot_builder.cpp +++ b/compiler/mir/src/DotGraph.cpp @@ -14,33 +14,28 @@ * limitations under the License. */ -#include "mir/ir_dot_builder.h" +#include "DotGraph.h" namespace mir { -void IrDotBuilder::updateWithOp(const Operation *op, const DotIrNodeInfo &ir_node_info) -{ - addNode(op, ir_node_info); - for (auto &prev : op->getInputs()) - { - addEdge(prev.getProducer()->getNode(), op); - } -} +void DotGraph::addNode(DotNode node) { _nodes.emplace_back(std::move(node)); } -void IrDotBuilder::writeDot(std::ostream &os) -{ - os << "digraph D {" << std::endl << dot.str() << std::endl << "}" << std::endl; -} +void DotGraph::addEdge(DotEdge edge) { _edges.emplace_back(edge); } -void IrDotBuilder::addNode(const Operation *op, const DotIrNodeInfo &ir_node) +std::ostream &operator<<(std::ostream &stream, const DotGraph &graph) { - dot << op->getId() << " [shape=record label=\"" << ir_node.getLabel() << "\"];" << std::endl; -} - -void IrDotBuilder::addEdge(const Operation *op1, const Operation *op2) -{ - dot << op1->getId() << " -> " << op2->getId() << ";" << std::endl; + stream << "digraph D {" << std::endl; + for (const auto &node : graph._nodes) + { + stream << node.id << " [shape=record label=\"" << node.label << "\"];" << std::endl; + } + for (const auto &edge : graph._edges) + { + stream << edge.src_id << " -> " << edge.dst_id << ";" << std::endl; + } + stream << "}" << std::endl; + return stream; } } // namespace mir diff --git a/compiler/mir/include/mir/ir_dot_builder.h b/compiler/mir/src/DotGraph.h similarity index 54% rename from compiler/mir/include/mir/ir_dot_builder.h rename to compiler/mir/src/DotGraph.h index 311ec71..29698bb 100644 --- a/compiler/mir/include/mir/ir_dot_builder.h +++ b/compiler/mir/src/DotGraph.h @@ -14,35 +14,42 @@ * limitations under the License. */ -#ifndef _MIR_IR_DOT_BUILDER_H_ -#define _MIR_IR_DOT_BUILDER_H_ +#ifndef _MIR_DOT_GRAPH_ +#define _MIR_DOT_GRAPH_ +#include #include - -#include "mir/ir_dot_node_info.h" +#include +#include namespace mir { -/** - * @brief Provides an API to add nodes and edges to the .dot Model IR representation - * and then write the whole graph to a provided stream. - */ -class IrDotBuilder +struct DotNode +{ + std::size_t id; + std::string label; +}; + +struct DotEdge +{ + std::size_t src_id; + std::size_t dst_id; +}; + +class DotGraph { public: - explicit IrDotBuilder() = default; + void addNode(DotNode node); + void addEdge(DotEdge edge); - void updateWithOp(const Operation *op, const DotIrNodeInfo &ir_node_info); - void writeDot(std::ostream &os); + friend std::ostream &operator<<(std::ostream &stream, const DotGraph &graph); private: - void addNode(const Operation *op, const DotIrNodeInfo &ir_node); - void addEdge(const Operation *op1, const Operation *op2); - - std::stringstream dot; + std::vector _nodes; + std::vector _edges; }; } // namespace mir -#endif //_MIR_IR_DOT_BUILDER_H_ +#endif //_MIR_DOT_GRAPH_ diff --git a/compiler/mir/src/DotNodeBuilder.cpp b/compiler/mir/src/DotNodeBuilder.cpp new file mode 100644 index 0000000..0afaa37 --- /dev/null +++ b/compiler/mir/src/DotNodeBuilder.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "DotNodeBuilder.h" +#include "mir/OpDefs.h" + +#include + +namespace mir +{ + +template static std::string toString(const std::vector &v) +{ + std::stringstream ss; + ss << "["; + for (std::size_t i = 0; i < v.size(); ++i) + { + if (i != 0) + ss << ", "; + ss << v[i]; + } + return ss.str(); +} + +DotNodeBuilder::DotNodeBuilder(const Operation &op) +{ + _type_name = getTypeName(op.getType()); + _id = op.getId(); + + for (std::size_t i = 0; i < op.getNumInputs(); ++i) + { + _in_shapes.push_back(toString(op.getInputShape(i))); + } + + for (std::size_t i = 0; i < op.getNumOutputs(); ++i) + { + _out_shapes.push_back(toString(op.getOutputShape(i))); + } + + // Get attributes. + const_cast(op).accept(this); +} + +void DotNodeBuilder::visit(ops::AvgPool2DOp &op) +{ + addAttribute("Window size", toString(op.getWindowSize())); + addAttribute("Strides", toString(op.getStrides())); + addAttribute("Padding before", toString(op.getPaddingBefore())); + addAttribute("Padding after", toString(op.getPaddingAfter())); + addAttribute("Include pad", std::to_string(op.getIncludePad())); +} + +void DotNodeBuilder::visit(ops::CappedReluOp &op) +{ + addAttribute("Cap", std::to_string(op.getCap())); +} + +void DotNodeBuilder::visit(ops::ConcatOp &op) +{ + addAttribute("Axis", std::to_string(op.getAxis())); +} + +void DotNodeBuilder::visit(ops::Conv2DOp &op) +{ + addAttribute("Strides", toString(op.getStrides())); + addAttribute("Padding before", toString(op.getPaddingBefore())); + addAttribute("Padding after", toString(op.getPaddingAfter())); + addAttribute("Data format", toString(op.getDataFormat())); +} + +void DotNodeBuilder::visit(ops::DepthwiseConv2DOp &op) +{ + addAttribute("Strides", toString(op.getStrides())); + addAttribute("Padding before", toString(op.getPaddingBefore())); + addAttribute("Padding after", toString(op.getPaddingAfter())); + addAttribute("Data format", toString(op.getDataFormat())); +} + +void DotNodeBuilder::visit(ops::MaxPool2DOp &op) +{ + addAttribute("Window size", toString(op.getWindowSize())); + addAttribute("Strides", toString(op.getStrides())); + addAttribute("Padding before", toString(op.getPaddingBefore())); + addAttribute("Padding after", toString(op.getPaddingAfter())); + addAttribute("Data format", toString(op.getDataFormat())); +} + +void DotNodeBuilder::visit(ops::SoftmaxOp &op) +{ + addAttribute("Axis", std::to_string(op.getAxis())); +} + +void DotNodeBuilder::visit(ops::SliceOp &op) +{ + addAttribute("Starts", toString(op.getStarts())); + addAttribute("Sizes", toString(op.getSizes())); +} + +void DotNodeBuilder::visit(ops::DeConv2DOp &op) +{ + addAttribute("Padding before", toString(op.getPaddingBefore())); + addAttribute("Padding after", toString(op.getPaddingAfter())); + addAttribute("Strides", toString(op.getStrides())); + addAttribute("Data format", toString(op.getDataFormat())); +} + +void DotNodeBuilder::visit(ops::EluOp &op) { addAttribute("Alpha", std::to_string(op.getAlpha())); } + +void DotNodeBuilder::visit(ops::SqueezeOp &op) +{ + addAttribute("Dims to squeeze", toString(op.getDimsToSqueeze())); +} + +void mir::DotNodeBuilder::visit(ops::PadOp &op) +{ + addAttribute("Padding before", toString(op.getPaddingBefore())); + addAttribute("Padding after", toString(op.getPaddingAfter())); + addAttribute("Padding value", std::to_string(op.getPaddingValue())); +} + +void DotNodeBuilder::visit(ops::ReduceMeanOp &op) +{ + addAttribute("Reduction dims", toString(op.getReductionDims())); + addAttribute("Keep dims", std::to_string(op.getKeepDims())); +} + +void DotNodeBuilder::visit(ops::ResizeOp &op) +{ + assert(op.getMode() == ops::ResizeOp::ResizeMethod::nearestNeighbor); + + addAttribute("Interpolation mode", "nearestNeighbor"); +} + +void DotNodeBuilder::visit(ops::TransposeOp &op) +{ + addAttribute("Axis order", toString(op.getAxisOrder())); +} + +void DotNodeBuilder::visit(ops::GatherOp &op) +{ + addAttribute("Axis", std::to_string(op.getAxis())); +} + +void DotNodeBuilder::visit(mir::ops::LeakyReluOp &op) +{ + addAttribute("Alpha", std::to_string(op.getAlpha())); +} + +void DotNodeBuilder::addAttribute(std::string name, std::string val) +{ + this->_attributes.emplace_back(std::move(name), std::move(val)); +} + +std::string DotNodeBuilder::getLabel() const +{ + std::stringstream ss; + + ss << "{" << _type_name << " | {{"; + + for (std::size_t i = 0; i < _in_shapes.size(); ++i) + { + if (i != 0) + ss << " | "; + ss << "in" << i << ": " << _in_shapes[i]; + } + + ss << " | "; + + for (std::size_t i = 0; i < _out_shapes.size(); ++i) + { + if (i != 0) + ss << " | "; + ss << "out" << i << ": " << _out_shapes[i]; + } + + ss << "} | {"; + + for (std::size_t i = 0; i < _attributes.size(); ++i) + { + if (i != 0) + ss << " | "; + ss << _attributes[i].first << ": " << _attributes[i].second; + } + + ss << "}}}"; + + return ss.str(); +} + +} // namespace mir diff --git a/compiler/mir/src/DotNodeBuilder.h b/compiler/mir/src/DotNodeBuilder.h new file mode 100644 index 0000000..09eba0c --- /dev/null +++ b/compiler/mir/src/DotNodeBuilder.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _MIR_DOT_NODE_BUILDER_H_ +#define _MIR_DOT_NODE_BUILDER_H_ + +#include "DotGraph.h" + +#include "mir/Visitor.h" + +#include +#include +#include + +namespace mir +{ + +class DotNodeBuilder : public Visitor +{ +public: + explicit DotNodeBuilder(const Operation &op); + + void visit(ops::AvgPool2DOp &op) override; + void visit(ops::CappedReluOp &op) override; + void visit(ops::ConcatOp &op) override; + void visit(ops::Conv2DOp &op) override; + void visit(ops::DeConv2DOp &op) override; + void visit(ops::DepthwiseConv2DOp &op) override; + void visit(ops::EluOp &op) override; + void visit(ops::GatherOp &op) override; + void visit(ops::LeakyReluOp &op) override; + void visit(ops::MaxPool2DOp &op) override; + void visit(ops::PadOp &op) override; + void visit(ops::ReduceMeanOp &op) override; + void visit(ops::ResizeOp &op) override; + void visit(ops::SliceOp &op) override; + void visit(ops::SoftmaxOp &op) override; + void visit(ops::SqueezeOp &op) override; + void visit(ops::TransposeOp &op) override; + + void addAttribute(std::string name, std::string val); + + DotNode getDotNode() const { return {_id, getLabel()}; } + +private: + std::string getLabel() const; + + std::size_t _id; + std::string _type_name; + std::vector _in_shapes; + std::vector _out_shapes; + std::vector> _attributes; +}; + +} // namespace mir + +#endif //_MIR_DOT_NODE_BUILDER_H_ diff --git a/compiler/mir/src/IrDotDumper.cpp b/compiler/mir/src/IrDotDumper.cpp index 829b53a..e40255c 100644 --- a/compiler/mir/src/IrDotDumper.cpp +++ b/compiler/mir/src/IrDotDumper.cpp @@ -14,351 +14,28 @@ * limitations under the License. */ -#include "mir/Graph.h" #include "mir/IrDotDumper.h" -#include "mir/OpDefs.h" - -#include +#include "mir/Graph.h" +#include "DotGraph.h" +#include "DotNodeBuilder.h" namespace mir { -static std::vector getInputShapes(const Operation &op) -{ - std::vector shapes; - for (std::size_t i = 0; i < op.getNumInputs(); ++i) - { - shapes.push_back(op.getInputShape(i)); - } - return shapes; -} - -static std::vector getOutputShapes(const Operation &op) -{ - std::vector shapes; - for (std::size_t i = 0; i < op.getNumOutputs(); ++i) - { - shapes.push_back(op.getOutputShape(i)); - } - return shapes; -} - -void IrDotDumper::visit(ops::AddOp &op) { addSimpleOperation(op); } - -void IrDotDumper::visit(ops::AvgPool2DOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("AvgPool2D") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withShape("Window size", Shape(op.getWindowSize())) - .withStride(Shape(op.getStrides())) - .withShape("Padding before", Shape(op.getPaddingBefore())) - .withShape("Padding after", Shape(op.getPaddingAfter())); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::CappedReluOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("CappedRelu") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withMisc("Cap", op.getCap()); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::ConcatOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("Concat") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withMisc("Axis", op.getAxis()); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::Conv2DOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("Conv2D") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withStride(Shape(op.getStrides())) - .withShape("Padding before", Shape(op.getPaddingBefore())) - .withShape("Padding after", Shape(op.getPaddingAfter())); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::DepthwiseConv2DOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("DepthwiseConv2D") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withStride(Shape(op.getStrides())) - .withShape("Padding before", Shape(op.getPaddingBefore())) - .withShape("Padding after", Shape(op.getPaddingAfter())); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::DivOp &op) { addSimpleOperation(op); } - -void IrDotDumper::visit(ops::FullyConnectedOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("FullyConnected") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::MaxPool2DOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("MaxPool2D") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withShape("Window size", Shape(op.getWindowSize())) - .withStride(Shape(op.getStrides())) - .withShape("Padding before", Shape(op.getPaddingBefore())) - .withShape("Padding after", Shape(op.getPaddingAfter())); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::SoftmaxOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("Softmax") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withMisc("Axis", op.getAxis()); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::ReluOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("ReLU") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::ReshapeOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("Reshape") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::InputOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("Input") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::ConstantOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("Constant") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::SliceOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("SliceOp") - .withInShapes(getInputShapes(op)) - .withShape("Starts", op.getStarts()) - .withShape("Sizes", op.getSizes()) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::DeConv2DOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("DeConv2D") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withPadType(op.getPaddingType()) - .withStride(Shape(op.getStrides())); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::EluOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("EluOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withMisc("Alpha", op.getAlpha()); -} - -void IrDotDumper::visit(ops::TanhOp &op) -{ - auto nodeInfo = DotIrNodeInfo() - .withType("TanhOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, nodeInfo); -} - -void IrDotDumper::visit(ops::SqueezeOp &op) +void dumpGraph(const Graph *graph, std::ostream &stream) { - auto node_info = DotIrNodeInfo() - .withType("SqueezeOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); + DotGraph dot_graph; - for (auto dim : op.getDimsToSqueeze()) + for (const auto *node : graph->getNodes()) { - node_info.withMisc("SqueezeDim", dim); + dot_graph.addNode(DotNodeBuilder(*node).getDotNode()); + for (const auto &input : node->getInputs()) + { + dot_graph.addEdge({input.getProducer()->getNode()->getId(), node->getId()}); + } } - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::SubOp &op) { addSimpleOperation(op); } - -void mir::IrDotDumper::visit(ops::PadOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("PadOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::SqrtOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("Sqrt") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::ReduceMeanOp &op) -{ - auto node_info = - DotIrNodeInfo() - .withType("ReduceMeanOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims - .withMisc("Keep dims", op.getKeepDims()); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::ResizeOp &op) -{ - static const std::map modes{ - {ops::ResizeOp::ResizeMethod::nearestNeighbor, "nearestNeighbor"}}; - - auto node_info = DotIrNodeInfo() - .withType("Resize") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withMisc("Mode", modes.at(op.getMode())); - // scale is only needed in Shape Inference - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::TransposeOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("TransposeOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::GatherOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("GatherOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::SigmoidOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("SigmoidOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(mir::ops::LeakyReluOp &op) -{ - auto node_info = DotIrNodeInfo() - .withType("LeakyReluOp") - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)) - .withMisc("alpha", op.getAlpha()); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::visit(ops::MaxOp &op) { addSimpleOperation(op); } - -void IrDotDumper::visit(ops::MulOp &op) { addSimpleOperation(op); } - -void IrDotDumper::visit(ops::OutputOp &op) -{ - auto node_info = DotIrNodeInfo().withType("OutputOp").withInShapes(getInputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void IrDotDumper::addSimpleOperation(const Operation &op) -{ - auto node_info = DotIrNodeInfo() - .withType(getTypeName(op.getType())) - .withInShapes(getInputShapes(op)) - .withOutShapes(getOutputShapes(op)); - - _dot_builder.updateWithOp(&op, node_info); -} - -void dumpGraph(const Graph *graph, std::ostream &stream) -{ - IrDotDumper dumper; - const_cast(graph)->accept(&dumper); - dumper.writeDot(stream); + stream << dot_graph; } } // namespace mir diff --git a/compiler/mir/src/Shape.cpp b/compiler/mir/src/Shape.cpp index 0ef79b5..d6250e2 100644 --- a/compiler/mir/src/Shape.cpp +++ b/compiler/mir/src/Shape.cpp @@ -18,6 +18,7 @@ #include #include +#include namespace mir { @@ -39,22 +40,23 @@ int32_t Shape::numElements() const return res; } -std::ostream &operator<<(std::ostream &s, const Shape &sh) +std::string toString(const Shape &shape) { - int32_t rank = sh.rank(); - s << "["; + std::stringstream ss; - for (int32_t axis = 0; axis < rank; ++axis) + ss << "["; + for (int32_t axis = 0; axis < shape.rank(); ++axis) { if (axis != 0) - s << ", "; - if (sh.dim(axis) == Shape::autoDim) - s << "AUTO"; + ss << ", "; + if (shape.dim(axis) == Shape::autoDim) + ss << "AUTO"; else - s << sh.dim(axis); + ss << shape.dim(axis); } - s << "]"; - return s; + ss << "]"; + + return ss.str(); } } // namespace mir diff --git a/compiler/mir/src/ir_dot_node_info.cpp b/compiler/mir/src/ir_dot_node_info.cpp deleted file mode 100644 index e0968ff..0000000 --- a/compiler/mir/src/ir_dot_node_info.cpp +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "mir/ir_dot_node_info.h" - -namespace mir -{ - -template <> DotIrNodeInfo::Stringable::Stringable(std::string val) : _val(std::move(val)) {} - -template <> DotIrNodeInfo::Stringable::Stringable(const char *val) : _val(val) {} - -DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &type_name) -{ - this->_type_name = type_name; - return *this; -} - -DotIrNodeInfo &DotIrNodeInfo::withInShapes(DotIrNodeInfo::Shapes &&in_shapes) -{ - this->_in_shapes = in_shapes; - return *this; -} - -DotIrNodeInfo &DotIrNodeInfo::withOutShapes(DotIrNodeInfo::Shapes &&out_shapes) -{ - this->_out_shapes = out_shapes; - return *this; -} - -DotIrNodeInfo &DotIrNodeInfo::withStride(const Shape &stride_shape) -{ - this->_stride_shape = stride_shape; - return *this; -} - -/** - * @brief Allows dumping arbitrary parameters from layer - * The values that are actually integers get dumped as integers. - */ -DotIrNodeInfo &DotIrNodeInfo::withMisc(const std::string &name, Stringable val) -{ - this->_misc_vals.emplace_back(name, std::move(val)); - return *this; -} - -DotIrNodeInfo &DotIrNodeInfo::withShape(const std::string &shape_name, const Shape &shape) -{ - this->_shapes.emplace_back(shape_name, shape); - return *this; -} - -DotIrNodeInfo &DotIrNodeInfo::withPadType(DotIrNodeInfo::PadType pad_type) -{ - this->_pad_type = pad_type; - this->_has_pad = true; - return *this; -} - -std::string DotIrNodeInfo::getLabel() const -{ - std::stringstream ss; - - ss << "{" << _type_name << " | "; - - // Note inputs and output shapes - ss << "{{"; - writeInShapesLabel(ss); - ss << " | "; - writeOutShapesLabel(ss); - ss << "}"; - - // Other node parameters - kernel shape, stride, padding type etc - std::string label = labelForNodeParams(); - if (!label.empty()) - ss << " | {" << label << "}"; - - ss << "}}"; - - return ss.str(); -} - -std::string DotIrNodeInfo::labelForPadAndPool() const -{ - std::stringstream ss; - - if (_has_pad) - { - ss << "{PadType: "; - switch (_pad_type) - { - case PadType::Valid: - ss << "VALID"; - break; - case PadType::Same: - ss << "SAME"; - break; - case PadType::Custom: - ss << "CUSTOM"; - break; - default: - assert(false && "Unknown Padding type"); - break; - } - - ss << "}"; - } - - return ss.str(); -} - -void DotIrNodeInfo::writeInShapesLabel(std::stringstream &ss) const -{ - if (_in_shapes.empty()) - ss << "IN_SHAPES_NOT_SET"; - else - { - for (Shapes::size_type i = 0; i < _in_shapes.size(); ++i) - { - if (i != 0) - ss << " | "; - ss << "in" << i << ": " << _in_shapes[i]; - } - } -} - -void DotIrNodeInfo::writeOutShapesLabel(std::stringstream &ss) const -{ - if (_out_shapes.empty()) - ss << "OUT_SHAPES_NOT_SET"; - else - { - for (Shapes::size_type i = 0; i < _out_shapes.size(); ++i) - { - if (i != 0) - ss << "| "; - ss << "out" << i << ": " << _out_shapes[i]; - } - } -} - -std::string DotIrNodeInfo::labelForNodeParams() const -{ - std::stringstream ss; - - bool need_pipe = false; - if (_kernel_shape.rank() != 0) - { - ss << "Kernel: " << _kernel_shape; - need_pipe = true; - } - - std::string label = labelForPadAndPool(); - addPipeIfNeeded(ss, !label.empty(), need_pipe); - ss << label; - - addPipeIfNeeded(ss, !_shapes.empty(), need_pipe); - for (Shapes::size_type i = 0; i < _shapes.size(); ++i) - { - if (i != 0) - ss << " | "; - ss << _shapes[i].first << ": " << _shapes[i].second; - } - - if (_stride_shape.rank() != 0) - { - addPipeIfNeeded(ss, true, need_pipe); - ss << "Stride: " << _stride_shape; - } - - // misc scalar parameters (Cap, dropRate, etc..) - addPipeIfNeeded(ss, !_misc_vals.empty(), need_pipe); - for (Shapes::size_type i = 0; i < _misc_vals.size(); ++i) - { - if (i != 0) - ss << " | "; - ss << _misc_vals[i].first << ": " << _misc_vals[i].second; - } - - return ss.str(); -} - -void DotIrNodeInfo::addPipeIfNeeded(std::stringstream &ss, bool needed, bool &need_pipe) const -{ - if (needed) - { - if (need_pipe) - ss << " | "; - else - need_pipe = true; - } -} - -} // namespace mir