#include "GraphBuilderContext.h"
#include "Knob.h"
+#include "IR/TFConv2D.h"
+
#include "Annotations/PaddingData.h"
+#include "Annotations/PadData.h"
#include <moco/tf/Names.h>
_ker_enc->input(ker_node);
}
+class TFConv2DGraphUpdate final : public GraphUpdate
+{
+public:
+ TFConv2DGraphUpdate(TFConv2D *node, std::vector<TensorName> names) : _node(node), _names(names) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFConv2D *_node;
+ std::vector<TensorName> _names;
+};
+
+void TFConv2DGraphUpdate::input(const SymbolTable *node_table) const
+{
+ assert(_names.size() == 2);
+
+ auto ifm_node = node_table->node(_names[0]);
+ auto ker_node = node_table->node(_names[1]);
+ assert(ifm_node != nullptr);
+ assert(ker_node != nullptr);
+
+ _node->ifm(ifm_node);
+ _node->ker(ker_node);
+}
+
} // namespace
namespace moco
void Conv2DGraphBuilder::buildTF(const tensorflow::NodeDef &node,
GraphBuilderContext *context) const
{
- throw std::runtime_error("NYI");
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // name of loco nodes
+ std::string conv2d_name = node.name();
+
+ auto conv2d = graph->nodes()->create<TFConv2D>();
+
+ // read attributes
+ auto data_layout = get_string_attr(node, "data_format");
+ if (!(data_layout == "NHWC" || data_layout == "NCHW"))
+ {
+ throw std::runtime_error("Not yet supported");
+ }
+ conv2d->data_layout(data_layout);
+
+ auto tf_strides = get_list_attr(node, "strides");
+ auto strides = as_int64_list(tf_strides);
+ conv2d->strides(strides);
+
+ auto padding = moco::str_toupper(get_string_attr(node, "padding"));
+ assert(padding == "VALID" || padding == "SAME");
+ conv2d->padding(padding);
+
+ // save the name for graph link updates
+ TensorName output_name(conv2d_name, 0);
+ tensor_names->enroll(output_name, conv2d);
+
+ std::vector<TensorName> input_names;
+ input_names.push_back(TensorName(node.input(0))); // input
+ input_names.push_back(TensorName(node.input(1))); // kernel
+
+ // Record ifm inputs to featureEncode_node
+ auto tfconv2d_update = stdex::make_unique<TFConv2DGraphUpdate>(conv2d, input_names);
+
+ updates->enroll(std::move(tfconv2d_update));
}
} // namespace tf