From d67c5a2cfe3bf6481bdfd96c8985320d12cdf8e1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 7 Aug 2019 18:21:13 +0900 Subject: [PATCH] [exo-tflite] Implement FilterShape converter (#6334) This commit implement a helper which converts FilterShape as exo-tflite internal ShapeDescription. Signed-off-by: Jonghyun Park --- compiler/exo-tflite/src/ExporterUtils.cpp | 18 ++++++++++++++++++ compiler/exo-tflite/src/ExporterUtils.h | 1 + 2 files changed, 19 insertions(+) diff --git a/compiler/exo-tflite/src/ExporterUtils.cpp b/compiler/exo-tflite/src/ExporterUtils.cpp index c573e1f..384133d 100644 --- a/compiler/exo-tflite/src/ExporterUtils.cpp +++ b/compiler/exo-tflite/src/ExporterUtils.cpp @@ -49,6 +49,22 @@ ShapeDescription to_shape_description(const loco::FeatureShape &shape) return res; } +ShapeDescription to_shape_description(const loco::FilterShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + // T/F Lite encodes a convolution filter as a NHWC tensor + res._dims.resize(4); + res._dims.at(0) = shape.count().value(); + res._dims.at(1) = shape.height().value(); + res._dims.at(2) = shape.width().value(); + res._dims.at(3) = shape.depth().value(); + + return res; +} + ShapeDescription to_shape_description(const loco::BiasShape &shape) { ShapeDescription res; @@ -69,6 +85,8 @@ ShapeDescription to_shape_description(const loco::NodeShape &shape) return to_shape_description(shape.as()); case loco::Domain::Feature: return to_shape_description(shape.as()); + case loco::Domain::Filter: + return to_shape_description(shape.as()); case loco::Domain::Bias: return to_shape_description(shape.as()); default: diff --git a/compiler/exo-tflite/src/ExporterUtils.h b/compiler/exo-tflite/src/ExporterUtils.h index c35a28e..cf61785 100644 --- a/compiler/exo-tflite/src/ExporterUtils.h +++ b/compiler/exo-tflite/src/ExporterUtils.h @@ -50,6 +50,7 @@ struct ShapeDescription ShapeDescription to_shape_description(const loco::TensorShape &shape); ShapeDescription to_shape_description(const loco::FeatureShape &shape); +ShapeDescription to_shape_description(const loco::FilterShape &shape); ShapeDescription to_shape_description(const loco::BiasShape &shape); ShapeDescription to_shape_description(const loco::NodeShape &shape); -- 2.7.4