#include <cassert>
#include <stdexcept>
+#include <algorithm>
namespace
{
// note: even though "data_format" is not entered when a model is written,
// TF seems to generate "data_format" field into a pb file
- return plier::tf::has_attrs(node, {"T", "data_format", "dilations", "padding", "strides"});
+ bool has_mandatory_attrs = plier::tf::has_attrs(node, {"T", "data_format", "padding", "strides"});
+ // dilation attribute is not fully supported
+ bool supported_dilations = true;
+ if (plier::tf::has_attr(node, "dilations"))
+ {
+ auto dilation = plier::tf::get_list_attr(node, "dilations").i();
+ supported_dilations =
+ std::all_of(dilation.begin(), dilation.end(), [](std::int64_t dil) { return dil == 1; });
+ }
+ return has_mandatory_attrs && supported_dilations;
}
void Conv2DGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const