if (const auto *graph_builder = tflimport::GraphBuilderRegistry::get().lookup(builtincode))
{
+ if (!graph_builder->validate(op))
+ {
+ throw std::runtime_error{"Invalid operator"};
+ }
+
graph_builder->build(op, &opbuilder_context);
}
else
class GraphBuilder
{
public:
+ /**
+ * TODO Declare "validate" method as a pure virtual method
+ *
+ * Q: Is it possible to validate T/F Lite model only with this interface?
+ */
+ virtual bool validate(const tflite::Operator *) const { return true; }
+
virtual void build(const tflite::Operator *op, GraphBuilderContext *context) const = 0;
virtual ~GraphBuilder() {}
};
namespace tflimport
{
+bool AvgPool2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+ auto const options = op->builtin_options_as_Pool2DOptions();
+
+ if ((options->stride_h() == 0) || (options->stride_w() == 0))
+ {
+ return false;
+ }
+
+ return true;
+}
+
void AvgPool2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const
{
assert(context != nullptr); // check if init(..) is called
class AvgPool2DGraphBuilder : public GraphBuilder
{
public:
+ bool validate(const tflite::Operator *op) const override;
void build(const tflite::Operator *op, GraphBuilderContext *) const override;
};
namespace tflimport
{
+bool Conv2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+ auto const options = op->builtin_options_as_Conv2DOptions();
+
+ if ((options->stride_h() == 0) || (options->stride_w() == 0))
+ {
+ return false;
+ }
+
+ return true;
+}
+
void Conv2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const
{
assert(context != nullptr);
class Conv2DGraphBuilder : public GraphBuilder
{
public:
+ bool validate(const tflite::Operator *op) const override;
void build(const tflite::Operator *op, GraphBuilderContext *context) const override;
};
namespace tflimport
{
+bool DepthwiseConv2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+ auto const options = op->builtin_options_as_DepthwiseConv2DOptions();
+
+ if ((options->stride_h() == 0) || (options->stride_w() == 0))
+ {
+ return false;
+ }
+
+ return true;
+}
+
void DepthwiseConv2DGraphBuilder::build(const tflite::Operator *op,
GraphBuilderContext *context) const
{
class DepthwiseConv2DGraphBuilder : public GraphBuilder
{
public:
+ bool validate(const tflite::Operator *op) const override;
void build(const tflite::Operator *op, GraphBuilderContext *context) const override;
};
namespace tflimport
{
+bool MaxPool2DGraphBuilder::validate(const tflite::Operator *op) const
+{
+ auto const options = op->builtin_options_as_Pool2DOptions();
+
+ if ((options->stride_h() == 0) || (options->stride_w() == 0))
+ {
+ return false;
+ }
+
+ return true;
+}
+
void MaxPool2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *context) const
{
assert(context != nullptr); // check if init(..) is called
class MaxPool2DGraphBuilder : public GraphBuilder
{
public:
+ bool validate(const tflite::Operator *op) const override;
void build(const tflite::Operator *op, GraphBuilderContext *) const override;
};
tflite::Padding padding, int stride_w, int stride_h,
int dilation_w_factor, int dilation_h_factor)
{
+ assert(stride_w != 0);
+ assert(stride_h != 0);
assert(ifm_shape.rank() == 4);
/**