#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
+#include "re2/re2.h"
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
}
}
+// Returns the array names that match the ArraysExtraInfo's name and
+// name_regexp. The regexp match is for a full match.
+std::unordered_set<string> ScanArrayNames(
+ const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
+ std::unordered_set<string> matches;
+ if (model.HasArray(entry.name())) {
+ matches.insert(entry.name());
+ }
+ if (!entry.name_regexp().empty()) {
+ const auto& arrays = model.GetArrayMap();
+ const RE2 name_regexp = {entry.name_regexp()};
+ for (auto it = arrays.begin(); it != arrays.end(); ++it) {
+ if (RE2::FullMatch(it->first, name_regexp)) {
+ matches.insert(it->first);
+ }
+ }
+ }
+ return matches;
+}
+
void UseArraysExtraInfo(Model* model, bool quantize_output) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
- if (!model->HasArray(entry.name())) {
- continue;
- }
- auto& array = model->GetArray(entry.name());
- if (entry.has_min() || entry.has_max()) {
- CHECK_EQ(entry.has_min(), entry.has_max());
- auto& minmax = array.GetOrCreateMinMax();
- minmax.min = entry.min();
- minmax.max = entry.max();
- }
- if (entry.has_data_type() && quantize_output) {
- array.final_data_type =
- ConvertIODataTypeToArrayDataType(entry.data_type());
- }
- if (entry.has_shape()) {
- array.clear_shape();
- // Make sure to create the shape even if there are no dims, to
- // correctly record 0-D shapes.
- array.mutable_shape();
- for (int dim : entry.shape().dims()) {
- array.mutable_shape()->mutable_dims()->push_back(dim);
+ const auto matches = ScanArrayNames(*model, entry);
+ for (const auto& matched_name : matches) {
+ auto& array = model->GetArray(matched_name);
+ if (entry.has_min() || entry.has_max()) {
+ CHECK_EQ(entry.has_min(), entry.has_max());
+ auto& minmax = array.GetOrCreateMinMax();
+ minmax.min = entry.min();
+ minmax.max = entry.max();
}
- }
- if (entry.has_constant_float_value()) {
- CHECK(array.has_shape());
- if (array.data_type == ArrayDataType::kFloat) {
- auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
- data.resize(RequiredBufferSizeForShape(array.shape()));
- for (float& f : data) {
- f = entry.constant_float_value();
+ if (entry.has_data_type() && quantize_output) {
+ array.final_data_type =
+ ConvertIODataTypeToArrayDataType(entry.data_type());
+ }
+ if (entry.has_shape()) {
+ array.clear_shape();
+ // Make sure to create the shape even if there are no dims, to
+ // correctly record 0-D shapes.
+ array.mutable_shape();
+ for (int dim : entry.shape().dims()) {
+ array.mutable_shape()->mutable_dims()->push_back(dim);
+ }
+ }
+ if (entry.has_constant_float_value()) {
+ CHECK(array.has_shape());
+ if (array.data_type == ArrayDataType::kFloat) {
+ auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ data.resize(RequiredBufferSizeForShape(array.shape()));
+ for (float& f : data) {
+ f = entry.constant_float_value();
+ }
}
}
}