From 6fc0da83bb5fdd0ebde4eefe62351805256d1961 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 26 Nov 2019 13:36:48 +0900 Subject: [PATCH] [moco-tf] Use oops for RsqrtCanonicalizer (#9208) This will update RsqrtCanonicalizer to use oops - modify internal method to handle error a bit easier Signed-off-by: SaeHie Park --- .../moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp index dfaeb1d..c31dbf6 100644 --- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp @@ -24,15 +24,16 @@ #include #include +#include namespace { template -void prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, T value); +bool prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, T value); template <> -void prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, +bool prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, float value) { LOGGER(l); @@ -49,7 +50,7 @@ void prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShap if (tensorshape.dim(r).known()) const_node->dim(r) = tensorshape.dim(r); else - throw std::runtime_error("Cannot handle unknown shape"); + return false; assert(tensorshape.dim(r).value() > 0); @@ -63,6 +64,8 @@ void prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShap { const_node->at(i) = value; } + + return true; } bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node) @@ -110,11 +113,12 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node) switch (dtype) { case loco::DataType::FLOAT32: - prepare_const_gen(const_node, tensorshape, 1.0f); + if (!prepare_const_gen(const_node, tensorshape, 1.0f)) + throw oops::UserExn("Cannot handle unknown shape", node->name()); break; default: - throw std::runtime_error("NYI for this DataType"); + throw oops::UserExn("Unsupported data type", node->name()); } auto node_A = node->x(); -- 2.7.4