#include "ShapeFixer.h"
+#include "Convert.h"
#include <stdexcept>
namespace neurun
assert(tensor_builder);
}
+void ShapeFixer::visit(const model::Subgraph &subgraph)
+{
+ assert(_lower_info_map != nullptr);
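+ // The subgraph's layout serves as the frontend layout when converting shapes below.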
+ _current_subg_layout = subgraph.getLayout();
+ for (const auto &e : subgraph.operations())
+ {
+ const auto &node = *(e.node);
+ node.accept(*this);
+
+ // NOTE Register backend tensor info for every operand of this node, converting
+ // each shape from the frontend (subgraph) layout to the operand's backend layout.
+ const auto frontend_layout = _current_subg_layout;
+ for (const auto &input : node.getInputs())
+ {
+ const auto &obj = _ctx.at(input);
+ const auto lower_info = _lower_info_map->operand.at(input).get();
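+ // Each operand is expected to have a single defining permute factor, hence getOnlyElement().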
+ const auto backend_layout = lower_info->def_factors().getOnlyElement().layout();
+ model::OperandInfo backend_info{asTensorShape(obj.shape(), frontend_layout, backend_layout),
+ obj.info().typeInfo()};
+ _tensor_builder->registerTensorInfo(input, backend_info, frontend_layout, backend_layout,
+ obj.isConstant());
+ }
+
+ for (const auto &output : node.getOutputs())
+ {
+ const auto &obj = _ctx.at(output);
+ const auto lower_info = _lower_info_map->operand.at(output).get();
+ const auto backend_layout = lower_info->def_factors().getOnlyElement().layout();
+ model::OperandInfo backend_info{asTensorShape(obj.shape(), frontend_layout, backend_layout),
+ obj.info().typeInfo()};
+ _tensor_builder->registerTensorInfo(output, backend_info, frontend_layout, backend_layout,
+ obj.isConstant());
+ }
+ }
+}
+
void ShapeFixer::visit(const model::operation::InstanceNorm &) { /* DO NOTHING */}
-void ShapeFixer::visit(const model::operation::TransposeConv &) { /* DO NOTHING */}
+
+void ShapeFixer::visit(const model::operation::TransposeConv &node)
+{
+ // Special case: the TransposeConv kernel needs a filter-layout-aware shape
+ // conversion (asKernelShape) rather than the generic tensor-shape permutation.
+ const auto &kernel_index = node.getInputs().at(model::operation::TransposeConv::KERNEL);
+ const auto &kernel_obj = _ctx.at(kernel_index);
+
+ const auto frontend_layout = _current_subg_layout;
+ assert(frontend_layout == model::Layout::NCHW || frontend_layout == model::Layout::NHWC);
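+ // Pick the filter (kernel) layout that corresponds to each tensor layout.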
+ const auto frontend_filter_layout = frontend_layout == model::Layout::NHWC
+ ? kernel::FilterLayout::OHWI
+ : kernel::FilterLayout::OIHW;
+ const auto lower_info = _lower_info_map->operand.at(kernel_index).get();
+ const auto backend_layout = lower_info->def_factors().getOnlyElement().layout();
+ assert(backend_layout == model::Layout::NCHW || backend_layout == model::Layout::NHWC);
+ const auto backend_filter_layout = backend_layout == model::Layout::NHWC
+ ? kernel::FilterLayout::HWOI
+ : kernel::FilterLayout::OIHW;
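+ // NOTE An NHWC backend expects HWOI-ordered kernels; this is the unusual weight
+ // permutation that TensorBuilder::registerTensorInfo registers first.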
+
+ model::OperandInfo backend_info{
+ asKernelShape(kernel_obj.shape(), frontend_filter_layout, backend_filter_layout),
+ kernel_obj.info().typeInfo()};
+ _tensor_builder->registerTensorInfo(kernel_index, backend_info, frontend_layout, backend_layout,
+ kernel_obj.isConstant());
+}
+
void ShapeFixer::visit(const model::operation::Add &) { /* DO NOTHING */}
} // namespace srcn
}
void TensorBuilder::registerTensorInfo(const model::OperandIndex &ind,
- const model::OperandInfo &info, model::Layout,
+ const model::OperandInfo &tensor_info, model::Layout,
model::Layout backend_layout, bool as_const)
{
- if (backend_layout == model::Layout::NCHW)
+ // NOTE Some of this backend's weight tensors are permuted in an unusual way,
+ // so ShapeFixer calls this function for all tensors to register the permuted
+ // tensor info first. Once registered, the info is not updated even if this
+ // function is called again elsewhere.
+ if (_tensor_info_map.find(ind) != _tensor_info_map.end())
{
- throw std::runtime_error("Not supported layout yet");
+ return;
}
- _tensor_info_map.emplace(ind, info);
+ _tensor_info_map.emplace(ind, tensor_info);
_tensor_layout_map.emplace(ind, backend_layout);
if (as_const)
void TensorBuilder::notifyFirstUse(const model::OperandIndex &ind)
{
assert(_tensor_info_map.find(ind) != _tensor_info_map.end());
- const auto &info = _tensor_info_map.at(ind);
- const auto size = info.total_size();
- // TODO Remove frontend_layout
- const auto frontend_layout = model::Layout::NHWC;
+ const auto &tensor_info = _tensor_info_map.at(ind);
+ const auto size = tensor_info.total_size();
const auto &backend_layout = _tensor_layout_map.at(ind);
- const auto tensor_info =
- asTensorInfo(info.shape(), info.typeInfo(), frontend_layout, backend_layout);
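+ // tensor_info was already converted to the backend layout when it was
+ // registered, so no frontend-to-backend conversion is needed here.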
_tensor_mgr->buildTensor(ind, tensor_info, backend_layout, _constants.contains(ind));
_tensor_mgr->claimPlan(ind, size);
}