Allow specifying in the arrays extra info file:
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Mar 2018 20:21:57 +0000 (13:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 20:25:05 +0000 (13:25 -0700)
 - the shape of the array
 - the hardcoding of the values of the array as a single repeated constant
   scalar value, turning an activations array into a constant array.

PiperOrigin-RevId: 190115218

tensorflow/contrib/lite/toco/model_flags.proto
tensorflow/contrib/lite/toco/tooling_util.cc

index 867b86f..42e0f54 100644 (file)
@@ -96,11 +96,13 @@ message RnnState {
 // model that does not already contain such MinMax information.
 message ArraysExtraInfo {
   message Entry {
-    // Next ID to use: 5.
+    // Next ID to use: 7.
     optional string name = 1;
     optional float min = 2;
     optional float max = 3;
     optional IODataType data_type = 4;
+    optional InputArrayShape shape = 5;
+    optional float constant_float_value = 6;
   }
   repeated Entry entries = 1;
 }
index ec1770c..f3f5048 100644 (file)
@@ -1972,9 +1972,9 @@ void FinishBuildingRNNStates(Model* model) {
 
 void UseArraysExtraInfo(Model* model) {
   for (const auto& entry : model->flags.arrays_extra_info().entries()) {
-    QCHECK(model->HasArray(entry.name()))
-        << "ArraysExtraInfo refers to non-existent array name: "
-        << entry.name();
+    if (!model->HasArray(entry.name())) {
+      continue;
+    }
     auto& array = model->GetArray(entry.name());
     auto& minmax = array.GetOrCreateMinMax();
     if (entry.has_min() || entry.has_max()) {
@@ -1986,6 +1986,24 @@ void UseArraysExtraInfo(Model* model) {
       array.final_data_type =
           ConvertIODataTypeToArrayDataType(entry.data_type());
     }
+    if (entry.has_shape()) {
+      array.clear_shape();
+      // Make sure to create the shape even if there are no dims, to
+      // correctly record 0-D shapes.
+      array.mutable_shape();
+      for (int dim : entry.shape().dims()) {
+        array.mutable_shape()->mutable_dims()->push_back(dim);
+      }
+    }
+    if (entry.has_constant_float_value()) {
+      CHECK(array.has_shape());
+      CHECK(array.data_type == ArrayDataType::kFloat);
+      auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+      data.resize(RequiredBufferSizeForShape(array.shape()));
+      for (float& f : data) {
+        f = entry.constant_float_value();
+      }
+    }
   }
 }