Automated g4 rollback of changelist 195091587
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 3 May 2018 01:11:25 +0000 (18:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 3 May 2018 01:13:39 +0000 (18:13 -0700)
PiperOrigin-RevId: 195184798

tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/model_flags.proto
tensorflow/contrib/lite/toco/tooling_util.cc

index f16225f..ce0a747 100644 (file)
@@ -397,6 +397,7 @@ cc_library(
         ":types_proto_cc",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
+        "@com_googlesource_code_re2//:re2",
         "@protobuf_archive//:protobuf_headers",
     ],
 )
index d23e80c..6c1c536 100644 (file)
@@ -96,8 +96,9 @@ message RnnState {
 // model that does not already contain such MinMax information.
 message ArraysExtraInfo {
   message Entry {
-    // Next ID to use: 7.
+    // Next ID to use: 8.
     optional string name = 1;
+    optional string name_regexp = 7;
     optional double min = 2;
     optional double max = 3;
     optional IODataType data_type = 4;
index f334c51..11293a5 100644 (file)
@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_replace.h"
 #include "absl/strings/str_split.h"
+#include "re2/re2.h"
 #include "tensorflow/contrib/lite/toco/dump_graphviz.h"
 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
 #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
@@ -1983,38 +1984,58 @@ void FinishBuildingRNNStates(Model* model) {
   }
 }
 
+// Returns the array names that match the ArraysExtraInfo's name and
+// name_regexp. The regexp match is for a full match.
+std::unordered_set<string> ScanArrayNames(
+    const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
+  std::unordered_set<string> matches;
+  if (model.HasArray(entry.name())) {
+    matches.insert(entry.name());
+  }
+  if (!entry.name_regexp().empty()) {
+    const auto& arrays = model.GetArrayMap();
+    const RE2 name_regexp = {entry.name_regexp()};
+    for (auto it = arrays.begin(); it != arrays.end(); ++it) {
+      if (RE2::FullMatch(it->first, name_regexp)) {
+        matches.insert(it->first);
+      }
+    }
+  }
+  return matches;
+}
+
 void UseArraysExtraInfo(Model* model, bool quantize_output) {
   for (const auto& entry : model->flags.arrays_extra_info().entries()) {
-    if (!model->HasArray(entry.name())) {
-      continue;
-    }
-    auto& array = model->GetArray(entry.name());
-    if (entry.has_min() || entry.has_max()) {
-      CHECK_EQ(entry.has_min(), entry.has_max());
-      auto& minmax = array.GetOrCreateMinMax();
-      minmax.min = entry.min();
-      minmax.max = entry.max();
-    }
-    if (entry.has_data_type() && quantize_output) {
-      array.final_data_type =
-          ConvertIODataTypeToArrayDataType(entry.data_type());
-    }
-    if (entry.has_shape()) {
-      array.clear_shape();
-      // Make sure to create the shape even if there are no dims, to
-      // correctly record 0-D shapes.
-      array.mutable_shape();
-      for (int dim : entry.shape().dims()) {
-        array.mutable_shape()->mutable_dims()->push_back(dim);
+    const auto matches = ScanArrayNames(*model, entry);
+    for (const auto& matched_name : matches) {
+      auto& array = model->GetArray(matched_name);
+      if (entry.has_min() || entry.has_max()) {
+        CHECK_EQ(entry.has_min(), entry.has_max());
+        auto& minmax = array.GetOrCreateMinMax();
+        minmax.min = entry.min();
+        minmax.max = entry.max();
       }
-    }
-    if (entry.has_constant_float_value()) {
-      CHECK(array.has_shape());
-      if (array.data_type == ArrayDataType::kFloat) {
-        auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
-        data.resize(RequiredBufferSizeForShape(array.shape()));
-        for (float& f : data) {
-          f = entry.constant_float_value();
+      if (entry.has_data_type() && quantize_output) {
+        array.final_data_type =
+            ConvertIODataTypeToArrayDataType(entry.data_type());
+      }
+      if (entry.has_shape()) {
+        array.clear_shape();
+        // Make sure to create the shape even if there are no dims, to
+        // correctly record 0-D shapes.
+        array.mutable_shape();
+        for (int dim : entry.shape().dims()) {
+          array.mutable_shape()->mutable_dims()->push_back(dim);
+        }
+      }
+      if (entry.has_constant_float_value()) {
+        CHECK(array.has_shape());
+        if (array.data_type == ArrayDataType::kFloat) {
+          auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+          data.resize(RequiredBufferSizeForShape(array.shape()));
+          for (float& f : data) {
+            f = entry.constant_float_value();
+          }
         }
       }
     }