Internal Change
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 23 Jan 2018 20:24:21 +0000 (12:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 23 Jan 2018 20:28:02 +0000 (12:28 -0800)
PiperOrigin-RevId: 182974191

58 files changed:
tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
tensorflow/contrib/lite/toco/dump_graphviz.cc
tensorflow/contrib/lite/toco/export_tensorflow.cc
tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
tensorflow/contrib/lite/toco/import_tensorflow.cc
tensorflow/contrib/lite/toco/model.h
tensorflow/contrib/lite/toco/tflite/export.cc
tensorflow/contrib/lite/toco/tflite/import_test.cc
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc

index d4da8f5dfe13a38e8b6886656c5c7e0c8fbb1316..5961d30bf5403df7fa6228e05124479d118dd279 100644 (file)
@@ -148,7 +148,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name,
   if (!IsAllocatableTransientArray(model, array_name)) {
     return 0;
   }
-  const auto& array = model.arrays.at(array_name);
+  const auto& array = &model.GetArray(array_name);
   CHECK(array->has_shape())
       << "Array '" << array_name << "' doesn't have a shape";
   if (array->data_type == ArrayDataType::kNone) {
@@ -185,7 +185,7 @@ void AllocateTransientArray(const Model& model, const string& array_name,
   }
   const std::size_t size =
       TransientArraySize(model, array_name, transient_data_alignment);
-  const auto& array = model.arrays.at(array_name);
+  const auto& array = &model.GetArray(array_name);
   CHECK(!array->alloc);
   allocator->Allocate(size, &array->GetOrCreateAlloc());
 }
@@ -197,7 +197,7 @@ void DeallocateTransientArray(const Model& model, const string& array_name,
   if (!IsAllocatableTransientArray(model, array_name)) {
     return;
   }
-  const auto& array = model.arrays.at(array_name);
+  const auto& array = &model.GetArray(array_name);
   CHECK(!!array->alloc);
   allocator->Deallocate(*array->alloc);
 }
@@ -231,7 +231,7 @@ void AllocateTransientArrays(Model* model,
   // Construct a sorted map of array names, so that other layout engines can
   // match exactly.
   std::map<string, const Array*> ordered_arrays_map;
-  for (const auto& pair : model->arrays) {
+  for (const auto& pair : model->GetArrayMap()) {
     ordered_arrays_map[pair.first] = pair.second.get();
   }
 
index 39809216c77bdadfd44aafbddc8e0979fde66a49..c726eb6d8678e2703f5acba8b3d8d740186939f5 100644 (file)
@@ -278,8 +278,8 @@ std::vector<const Operator*> OperatorsToDump(const Model& model) {
   if (last_specified) {
     // Return only the part of the graph between graphviz_first_array
     // and graphviz_last_array.
-    CHECK(model.arrays.count(dump_options.graphviz_first_array));
-    CHECK(model.arrays.count(dump_options.graphviz_last_array));
+    CHECK(model.HasArray(dump_options.graphviz_first_array));
+    CHECK(model.HasArray(dump_options.graphviz_last_array));
     std::unordered_set<string> arrays_already_produced;
     std::vector<string> arrays_to_produce;
     arrays_to_produce.push_back(dump_options.graphviz_last_array);
@@ -336,7 +336,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
             op_properties.color.TextColorString().c_str());
     // Add nodes and edges for all inputs of the operator.
     for (const auto& input : op.inputs) {
-      if (model.arrays.count(input) == 0) {
+      if (!model.HasArray(input)) {
         // Arrays should _always_ exist. Except, perhaps, during development.
         continue;
       }
@@ -352,7 +352,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
     }
     // Add nodes and edges for all outputs of the operator.
     for (const auto& output : op.outputs) {
-      if (model.arrays.count(output) == 0) {
+      if (!model.HasArray(output)) {
         // Arrays should _always_ exist. Except, perhaps, during development.
         continue;
       }
index 90fa442746cdee975b0103ce60817a95f9b31086..4fc01dbc20272eb863b0b22d6a1ef7b27c499981 100644 (file)
@@ -156,8 +156,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
   const_op->set_name(name);
   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
-  CHECK(model.arrays.count(name));
-  const auto& input_array = *model.arrays.at(name);
+  CHECK(model.HasArray(name));
+  const auto& input_array = model.GetArray(name);
   const auto& input_shape = input_array.shape();
   CHECK(input_array.buffer);
   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
@@ -177,8 +177,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
   const_op->set_name(name);
   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
-  CHECK(model.arrays.count(name));
-  const auto& input_array = *model.arrays.at(name);
+  CHECK(model.HasArray(name));
+  const auto& input_array = model.GetArray(name);
   const auto& input_shape = input_array.shape();
   CHECK(input_array.buffer);
   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
@@ -193,8 +193,8 @@ void ConvertIntTensorConst(const Model& model, const string& name,
   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
     return;
   }
-  CHECK(model.arrays.count(name));
-  const auto& array = *model.arrays.at(name);
+  CHECK(model.HasArray(name));
+  const auto& array = model.GetArray(name);
   auto* const_op = tensorflow_graph->add_node();
   const_op->set_op("Const");
   const_op->set_name(name);
@@ -324,7 +324,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
     biasadd_op->add_input(conv_output);
     biasadd_op->add_input(src_op.inputs[2]);
     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
-    CHECK(model.arrays.count(src_op.inputs[2]));
+    CHECK(model.HasArray(src_op.inputs[2]));
     const string& bias_array_name =
         WalkUpToConstantArray(model, src_op.inputs[2]);
     const auto& bias_array = model.GetArray(bias_array_name);
@@ -361,7 +361,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
   // We need to convert that to H x W x InputDepth x Multiplier.
   // That's only a matter of constructing a Dims object; the actual
   // array layout is the same.
-  CHECK(model.arrays.count(src_op.inputs[1]));
+  CHECK(model.HasArray(src_op.inputs[1]));
   const string& src_weights_name =
       WalkUpToConstantArray(model, src_op.inputs[1]);
   const auto& src_weights_array = model.GetArray(src_weights_name);
@@ -404,7 +404,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
     biasadd_op->add_input(conv_output);
     biasadd_op->add_input(src_op.inputs[2]);
     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
-    CHECK(model.arrays.count(src_op.inputs[2]));
+    CHECK(model.HasArray(src_op.inputs[2]));
     const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
     const auto& bias_array = model.GetArray(bias_name);
     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
@@ -469,10 +469,10 @@ void ConvertFullyConnectedOperator(const Model& model,
   (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
   (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
   (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
-  CHECK(model.arrays.count(src_op.inputs[1]));
+  CHECK(model.HasArray(src_op.inputs[1]));
   const string& fc_weights_name =
       WalkUpToConstantArray(model, src_op.inputs[1]);
-  const auto& fc_weights_array = *model.arrays.at(fc_weights_name);
+  const auto& fc_weights_array = model.GetArray(fc_weights_name);
   const auto& fc_weights_shape = fc_weights_array.shape();
   CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
   CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
@@ -492,8 +492,8 @@ void ConvertFullyConnectedOperator(const Model& model,
     biasadd_op->add_input(matmul_output);
     biasadd_op->add_input(src_op.inputs[2]);
     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
-    CHECK(model.arrays.count(src_op.inputs[2]));
-    const auto& bias_array = *model.arrays.at(src_op.inputs[2]);
+    CHECK(model.HasArray(src_op.inputs[2]));
+    const auto& bias_array = model.GetArray(src_op.inputs[2]);
     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
     Shape bias_shape_1d = bias_array.shape();
     UnextendShape(&bias_shape_1d, 1);
@@ -625,7 +625,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
     *reshape_op->add_input() = softmax_size;
     (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
 
-    const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape();
+    const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
     int32 flattened_size = 1;
     for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
       flattened_size *= input_shape.dims(i);
@@ -1013,8 +1013,8 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
   // Op names have been chosen to match the tf.slim LSTM naming
   // as closely as possible.
   const int axis =
-      model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
-          ->shape()
+      model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
+          .shape()
           .dimensions_count() -
       1;
   // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
@@ -1033,9 +1033,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
 
   // Write weights
   const string weights_output = base + "weights";
-  CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
+  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
   const auto& weights_array =
-      *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
+      model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
   // Convert 4D FullyConnected weights into 2D matrix
   const auto& weights_shape = weights_array.shape();
   CHECK_EQ(weights_shape.dimensions_count(), 2);
@@ -1059,9 +1059,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
 
   // Write biases
   const string biases_output = base + "biases";
-  CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
+  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
   const auto& bias_array =
-      *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
+      model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
   // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
   Shape bias_shape_1d = bias_array.shape();
   UnextendShape(&bias_shape_1d, 1);
@@ -1557,7 +1557,7 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
   (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
 
   auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
-  const auto& state_array = *model.arrays.at(name);
+  const auto& state_array = model.GetArray(name);
   if (state_array.has_shape()) {
     const auto& state_shape = state_array.shape();
     const int kDims = state_shape.dimensions_count();
@@ -1574,7 +1574,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
                                             GraphDef* tensorflow_graph) {
   for (const auto& input_array : model.flags.input_arrays()) {
     AddPlaceholder(input_array.name(),
-                   model.arrays.at(input_array.name())->data_type,
+                   model.GetArray(input_array.name()).data_type,
                    tensorflow_graph);
   }
   for (const auto& rnn_state : model.flags.rnn_states()) {
@@ -1588,7 +1588,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
   // by the above operators export. It's important that this comes
   // after, as some operators need to export arrays that they reference
   // in a specific way, rather than in the generic way done below.
-  for (const auto& array_pair : model.arrays) {
+  for (const auto& array_pair : model.GetArrayMap()) {
     const string& array_name = array_pair.first;
     const auto& array = *array_pair.second;
     if (array.buffer) {
index 3bde9b0169ddfb7fc37657122e2e8eb65ccbdf6d..56f48d47de4e86ece76ceef1d09a25f50957a8dc 100644 (file)
@@ -35,7 +35,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(expand_op->inputs.size(), 2);
   CHECK_EQ(expand_op->outputs.size(), 1);
 
-  const auto& input_array = *model->arrays[expand_op->inputs[0]];
+  const auto& input_array = model->GetArray(expand_op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
     return false;
@@ -46,7 +46,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
     return false;
   }
 
-  const auto& axis_array = *model->arrays[expand_op->inputs[1]];
+  const auto& axis_array = model->GetArray(expand_op->inputs[1]);
   if (!axis_array.has_shape()) {
     // Yield until input axis array shape has been resolved.
     return false;
@@ -86,7 +86,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
   if (IsDiscardableArray(*model, axis_array_name) &&
       CountOpsWithInput(*model, axis_array_name) == 1 &&
       !GetOpWithOutput(*model, axis_array_name)) {
-    model->arrays.erase(axis_array_name);
+    model->EraseArray(axis_array_name);
   }
 
   // Replace the operator in the graph.
index bf454c40c7b50d242d8a7e9eb6b7e579fb0da217..d38db85280d7bd935a47cda70227d383a513fbac 100644 (file)
@@ -58,7 +58,7 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
   depthwiseconv_op->outputs = {conv_op->outputs[0]};
   if (conv_op->outputs.size() > 1) {
     // delete the im2col array.
-    model->arrays.erase(conv_op->outputs[1]);
+    model->EraseArray(conv_op->outputs[1]);
   }
   depthwiseconv_op->fused_activation_function =
       conv_op->fused_activation_function;
index a234c209240ecb9eeba1d2e416a294be53d221ee..c2b166033c33b777bad88cb712adf8517be1762a 100644 (file)
@@ -29,7 +29,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   TransposeOperator* transpose_op =
       static_cast<TransposeOperator*>(transpose_it->get());
 
-  const auto& output_array = *model->arrays[transpose_op->outputs[0]];
+  const auto& output_array = model->GetArray(transpose_op->outputs[0]);
   if (!output_array.has_shape()) {
     // Yield until PropagateFixedSizes has been run on this op.
     return false;
@@ -70,7 +70,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   // Delete perm array if unused
   if (IsDiscardableArray(*model, perm_array_name) &&
       CountOpsWithInput(*model, perm_array_name) == 1) {
-    model->arrays.erase(perm_array_name);
+    model->EraseArray(perm_array_name);
   }
 
   // Replace the operator in the graph.
index 1735b51e5b6ca517bad62bf55f0cc9f0c21ac440..076415ece8c1039caa32e947fe54ab3e101bec9e 100644 (file)
@@ -35,7 +35,7 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
     // We already have an im2col array
     return false;
   }
-  const auto& weights_array = *model->arrays[conv_op->inputs[1]];
+  const auto& weights_array = model->GetArray(conv_op->inputs[1]);
   if (!weights_array.has_shape()) {
     // We need to yield until weights dims have been resolved, because
     // from the weights dims we determine whether an im2col array is
index 79854cba348227226b1456b71c89746239ebfc06..498c864bde6d656c8318e981204cb42cb3a4d03f 100644 (file)
@@ -53,7 +53,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
 }
 
 void ClearArrayQuantizationParams(const string& array_name, Model* model) {
-  auto* array = model->arrays.at(array_name).get();
+  auto* array = &model->GetArray(array_name);
   CHECK(array->quantization_params);
   for (auto& input_array : *model->flags.mutable_input_arrays()) {
     if (input_array.name() == array_name) {
@@ -77,7 +77,7 @@ void ClearArrayQuantizationParams(const string& array_name, Model* model) {
 
 bool DequantizeArray(const string& array_name,
                      GraphTransformation* transformation, Model* model) {
-  auto* array = model->arrays.at(array_name).get();
+  auto* array = &model->GetArray(array_name);
   if (!array->quantization_params) {
     return false;
   }
index fea360740f4e645e1f00eaed42cbff48f430fe2a..95558ef5ece9a78825daf0203e2f6f6fee6f3cda 100644 (file)
@@ -45,7 +45,7 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
   // Drop min/max inputs
   for (int i = 1; i < fakequant_op->inputs.size(); i++) {
     if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
-      model->arrays.erase(fakequant_op->inputs[i]);
+      model->EraseArray(fakequant_op->inputs[i]);
     }
   }
   fakequant_op->inputs.resize(1);
index a3ed6663bcc80c5fc642a399b1e5c0cf3336973a..f7fd878b7e8b1c834125130ea2a778cecefd3de0 100644 (file)
@@ -32,7 +32,7 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
 
   // Drop the im2col array.
   CHECK_EQ(conv_op->outputs.size(), 2);
-  model->arrays.erase(conv_op->outputs[1]);
+  model->EraseArray(conv_op->outputs[1]);
   conv_op->outputs.resize(1);
   AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
 
index ad4a6f9b78b06fd738da40c2054c07e8f272ee17..88e59664ec427841df6f20686238feacef6a47e9 100644 (file)
@@ -91,7 +91,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
   } else {
     LOG(FATAL) << "Unhandled activation function type";
   }
-  model->arrays.erase(ac_op->inputs[0]);
+  model->EraseArray(ac_op->inputs[0]);
   op->outputs[0] = ac_op->outputs[0];
   model->operators.erase(ac_it);
   return true;
index 4619d8bbee2e52483a523277f421de5bfa155635..dcbbead517f26a227363989b5af2a4040c98ff57 100644 (file)
@@ -285,13 +285,13 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
   AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
               LogName(*following_op));
 
-  model->arrays.erase(binary_op->outputs[0]);
+  model->EraseArray(binary_op->outputs[0]);
   following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
   const auto& old_constant_param_name =
       binary_op->inputs[index_of_constant_input];
   CHECK(IsConstantParameterArray(*model, old_constant_param_name));
   if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
-    model->arrays.erase(old_constant_param_name);
+    model->EraseArray(old_constant_param_name);
   }
   model->operators.erase(binary_it);
   return true;
index 8948653ec38f5a5a6e92cfe9e6bafdbf1aa9a962..5b57178b18d2d60e1f301a1a8b257d8057618550 100644 (file)
@@ -309,7 +309,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
     LOG(FATAL) << "should not get here";
   }
 
-  model->arrays.erase(preceding_op->outputs[0]);
+  model->EraseArray(preceding_op->outputs[0]);
   preceding_op->outputs[0] = binary_op->outputs[0];
   preceding_op->fused_activation_function =
       binary_op->fused_activation_function;
@@ -317,7 +317,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
       binary_op->inputs[index_of_constant_input];
   CHECK(IsConstantParameterArray(*model, old_constant_param_name));
   if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
-    model->arrays.erase(old_constant_param_name);
+    model->EraseArray(old_constant_param_name);
   }
   model->operators.erase(binary_it);
   return true;
index f861c4147a04fe31b7236bfa22ed4627f7742d09..2340f0e850ef6b995e9147af00e85c0e15887c7f 100644 (file)
@@ -31,13 +31,13 @@ namespace {
 
 void PrintModelStats(const string& label, const Model& model) {
   int quantized_arrays = 0;
-  for (const auto& array : model.arrays) {
+  for (const auto& array : model.GetArrayMap()) {
     if (array.second->quantization_params) {
       quantized_arrays++;
     }
   }
   LOG(INFO) << label << ": " << model.operators.size() << " operators, "
-            << model.arrays.size() << " arrays (" << quantized_arrays
+            << model.GetArrayMap().size() << " arrays (" << quantized_arrays
             << " quantized)";
 }
 
@@ -91,14 +91,9 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
     }
   } while (found_new_useful_arrays);
   // Erase arrays that aren't useful, and that are discardable.
-  for (auto it = model->arrays.begin(); it != model->arrays.end();) {
-    if (useful_arrays.count(it->first) ||
-        !IsDiscardableArray(*model, it->first)) {
-      ++it;
-    } else {
-      it = model->arrays.erase(it);
-    }
-  }
+  model->EraseArrays([&](const string& name) {
+    return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
+  });
   // Erase operators that do not produce a useful output array.
   for (auto it = model->operators.begin(); it != model->operators.end();) {
     // Only need to test the first output, as we simultaneously added all of
@@ -118,8 +113,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
   std::vector<RnnState> rnn_states_to_keep;
   for (const auto& rnn_state : model->flags.rnn_states()) {
     const bool dangling =
-        !model->arrays.count(rnn_state.back_edge_source_array()) ||
-        !model->arrays.count(rnn_state.state_array());
+        !model->HasArray(rnn_state.back_edge_source_array()) ||
+        !model->HasArray(rnn_state.state_array());
     if (dangling) {
       CHECK(rnn_state.discardable());
     } else {
index 01b75e37c691d48fabf8832af04543be3f5eb3bc..419a0776a6b987a18df059d3c1d4bf4370cd24d8 100644 (file)
@@ -150,19 +150,19 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
 
   // Erase the subgraph that is now replaced by L2Normalization
   model->operators.erase(FindOperator(model, square_op));
-  model->arrays.erase(sum_op->inputs[0]);
+  model->EraseArray(sum_op->inputs[0]);
   if (sum_op->inputs.size() > 1) {
-    model->arrays.erase(sum_op->inputs[1]);
+    model->EraseArray(sum_op->inputs[1]);
   }
   model->operators.erase(FindOperator(model, sum_op));
   if (add_op) {
-    model->arrays.erase(add_op->inputs[0]);
-    model->arrays.erase(add_op->inputs[1]);
+    model->EraseArray(add_op->inputs[0]);
+    model->EraseArray(add_op->inputs[1]);
     model->operators.erase(FindOperator(model, add_op));
   }
-  model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]);
+  model->EraseArray(sqrt_or_rsqrt_op->inputs[0]);
   model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
-  model->arrays.erase(div_or_mul_op->inputs[1]);
+  model->EraseArray(div_or_mul_op->inputs[1]);
   model->operators.erase(FindOperator(model, div_or_mul_op));
   return true;
 }
index 1865416fc2226d663dfd51a5c0a0e2129caf485c..e4d52476c649de53b3ab663f53ce7a5538dbb5ab 100644 (file)
@@ -92,8 +92,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
   AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
 
   // Erase intermediate arrays, keeping input to square op.
-  model->arrays.erase(avpool_op->inputs[0]);
-  model->arrays.erase(sqrt_op->inputs[0]);
+  model->EraseArray(avpool_op->inputs[0]);
+  model->EraseArray(sqrt_op->inputs[0]);
 
   // Erase three operators being replaced.
   model->operators.erase(FindOperator(model, square_op));
index cfc77024e7e56038878570c9d3a462715a53ae3f..d36e95060937d6af0789766bcb29ae70cef4569d 100644 (file)
@@ -89,12 +89,12 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
   AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
 
   // Erase Maximum scalar input & operator
-  model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+  model->EraseArray(maximum_op->inputs[scalar_input_index]);
   model->operators.erase(FindOperator(model, maximum_op));
 
   // Erase Minimum inputs & operator
-  model->arrays.erase(minimum_op->inputs[0]);
-  model->arrays.erase(minimum_op->inputs[1]);
+  model->EraseArray(minimum_op->inputs[0]);
+  model->EraseArray(minimum_op->inputs[1]);
   model->operators.erase(FindOperator(model, minimum_op));
 
   return true;
index 29b55d9bfcdba39840fe0262140c2bffbf7e0b72..f0d107232b4517115aa3f64b39b825dbaffb83ce 100644 (file)
@@ -27,7 +27,7 @@ namespace {
 void SetDataTypeForAllOutputs(Model* model, Operator* op,
                               ArrayDataType data_type) {
   for (const auto& output : op->outputs) {
-    model->arrays[output]->data_type = data_type;
+    model->GetArray(output).data_type = data_type;
   }
 }
 }  // namespace
@@ -39,7 +39,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
   // If the data type of some input is unknown, we need to yield.
   for (const auto& input : op->inputs) {
     if (!model->IsOptionalArray(input) &&
-        model->arrays[input]->data_type == ArrayDataType::kNone) {
+        model->GetArray(input).data_type == ArrayDataType::kNone) {
       return false;
     }
   }
@@ -47,7 +47,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
   // end if we changed anything, and return the correct boolean value.
   std::unordered_map<string, ArrayDataType> old_output_data_types;
   for (const auto& output : op->outputs) {
-    old_output_data_types[output] = model->arrays[output]->data_type;
+    old_output_data_types[output] = model->GetArray(output).data_type;
   }
   // Do the actual output data types propagation.
   if (op->type == OperatorType::kDequantize ||
@@ -69,18 +69,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
              op->type == OperatorType::kFill) {
     // These operators produce an output with the same type as their 2nd input
     CHECK_GE(op->inputs.size(), 2);
-    const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+    const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
     SetDataTypeForAllOutputs(model, op, data_type);
   } else if (op->type == OperatorType::kCast) {
     // Data type of the Cast op is specified.
     CHECK_EQ(op->outputs.size(), 1);
     auto* cast_op = static_cast<CastOperator*>(op);
-    model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+    model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
   } else if (op->type == OperatorType::kArgMax) {
     // Data type of the ArgMax op is specified.
     CHECK_EQ(op->outputs.size(), 1);
     auto* argmax_op = static_cast<ArgMaxOperator*>(op);
-    model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type;
+    model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
   } else if (op->type == OperatorType::kRange) {
     auto* range_op = static_cast<RangeOperator*>(op);
     // Output type of the Range op can be set via an attribute
@@ -91,7 +91,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
     } else {
       // Otherwise use the first input
       CHECK_GE(op->inputs.size(), 1);
-      data_type = model->arrays[op->inputs[0]]->data_type;
+      data_type = model->GetArray(op->inputs[0]).data_type;
     }
     CHECK_EQ(op->outputs.size(), 1);
     SetDataTypeForAllOutputs(model, op, data_type);
@@ -103,7 +103,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
     for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
       auto output = op->outputs[i];
       auto data_type = unsupported_op->output_data_types[i];
-      model->arrays[output]->data_type = data_type;
+      model->GetArray(output).data_type = data_type;
     }
   } else if (op->type == OperatorType::kExpandDims) {
     // Yield on ExpandDim until it is converted to Reshape
@@ -111,12 +111,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
   } else {
     // These operators produce outputs with the same type as their 1st input
     CHECK_GT(op->inputs.size(), 0);
-    const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+    const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
     SetDataTypeForAllOutputs(model, op, data_type);
   }
   // Return true if any output data type changed, false if none changed.
   for (const auto& output : op->outputs) {
-    if (old_output_data_types[output] != model->arrays[output]->data_type) {
+    if (old_output_data_types[output] != model->GetArray(output).data_type) {
       return true;
     }
   }
index a939efb4dbbc6ec0af2e44270d7c028eff882b70..ff0a3bd8819dd6f3413cac11b0a64b727a37bd3d 100644 (file)
@@ -85,7 +85,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
 
 int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
   const string& weights_name = op.inputs[1];
-  const auto& weights_shape = model.arrays.at(weights_name)->shape();
+  const auto& weights_shape = model.GetArray(weights_name).shape();
   if (op.type == OperatorType::kConv ||
       op.type == OperatorType::kFullyConnected) {
     return weights_shape.dims(0);
@@ -98,7 +98,7 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
 
 bool EnsureBiasVectorShape(Model* model, Operator* op) {
   const string& weights_name = op->inputs[1];
-  const auto& weights_array = *model->arrays[weights_name];
+  const auto& weights_array = model->GetArray(weights_name);
   // Yield until weights shape has been resolved.
   if (!weights_array.has_shape()) {
     return false;
@@ -107,7 +107,7 @@ bool EnsureBiasVectorShape(Model* model, Operator* op) {
   if (op->inputs.size() < 3) {
     return false;
   }
-  auto& bias_array = *model->arrays[op->inputs[2]];
+  auto& bias_array = model->GetArray(op->inputs[2]);
   if (bias_array.has_shape()) {
     return true;
   }
@@ -126,7 +126,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
     return;
   }
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -134,7 +134,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
   const auto& input_shape = input_array.shape();
   CHECK_EQ(input_shape.dimensions_count(), 4);
 
-  const auto& weights_array = *model->arrays[op->inputs[1]];
+  const auto& weights_array = model->GetArray(op->inputs[1]);
   // Yield until weights dims have been resolved.
   if (!weights_array.has_shape()) {
     return;
@@ -156,7 +156,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
   if (op->outputs.size() == 2) {
     const auto& output_shape = output_array.shape();
     const int input_depth = weights_shape.dims(3);
-    auto& im2col_array = *model->arrays[op->outputs[1]];
+    auto& im2col_array = model->GetArray(op->outputs[1]);
     im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
                                   output_shape.dims(2),
                                   input_depth * kheight * kwidth});
@@ -168,7 +168,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
     return;
   }
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -176,7 +176,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
   const auto& input_shape = input_array.shape();
   CHECK_EQ(input_shape.dimensions_count(), 4);
 
-  const auto& weights_array = *model->arrays[op->inputs[1]];
+  const auto& weights_array = model->GetArray(op->inputs[1]);
   // Yield until weights dims have been resolved.
   if (!weights_array.has_shape()) {
     return;
@@ -209,7 +209,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
 }
 
 void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -232,7 +232,7 @@ void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
 }
 
 void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -258,7 +258,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
 void ProcessFillOperator(Model* model, FillOperator* op) {
   CHECK_EQ(op->inputs.size(), 2);
   CHECK_EQ(op->outputs.size(), 1);
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     // We have already run
     return;
@@ -287,7 +287,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
     return;
   }
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -295,7 +295,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
   const auto& input_shape = input_array.shape();
   CHECK_GE(input_shape.dimensions_count(), 1);
 
-  const auto& weights_array = *model->arrays[op->inputs[1]];
+  const auto& weights_array = model->GetArray(op->inputs[1]);
   // Yield until weights dims have been resolved.
   if (!weights_array.has_shape()) {
     return;
@@ -315,13 +315,13 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
 
 void ProcessTensorFlowReshapeOperator(Model* model,
                                       TensorFlowReshapeOperator* op) {
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     // We have already run
     return;
   }
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
     return;
@@ -377,14 +377,14 @@ void ProcessTensorFlowReshapeOperator(Model* model,
 }
 
 void ProcessSimpleOperator(Model* model, Operator* op) {
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
   }
 
   const string& output_name = op->outputs[0];
-  auto& output_array = *model->arrays[output_name];
+  auto& output_array = model->GetArray(output_name);
   if (output_array.has_shape()) {
     return;
   }
@@ -394,14 +394,14 @@ void ProcessSimpleOperator(Model* model, Operator* op) {
 
 void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
   CHECK_EQ(op->inputs.size(), 2);
-  const auto& input0_array = *model->arrays[op->inputs[0]];
-  const auto& input1_array = *model->arrays[op->inputs[1]];
+  const auto& input0_array = model->GetArray(op->inputs[0]);
+  const auto& input1_array = model->GetArray(op->inputs[1]);
   // Yield until input dims have been resolved.
   if (!input0_array.has_shape() || !input1_array.has_shape()) {
     return;
   }
   const string& output_name = op->outputs[0];
-  auto& output_array = *model->arrays[output_name];
+  auto& output_array = model->GetArray(output_name);
   ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
                                   &output_array);
 }
@@ -424,11 +424,11 @@ bool KeepDims(const Operator& op) {
 
 void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
   CHECK_LE(op->inputs.size(), 2);
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     return;
   }
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   if (!input_array.has_shape()) {
     return;
   }
@@ -436,7 +436,7 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
   const bool keep_dims = KeepDims(*op);
   if (op->inputs.size() == 2) {
     // There is a reduction_indices input.
-    const auto& reduction_array = *model->arrays[op->inputs[1]];
+    const auto& reduction_array = model->GetArray(op->inputs[1]);
     if (!reduction_array.buffer) {
       return;
     }
@@ -476,11 +476,11 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) {
   if (op->begin.empty()) return;
 
   // Yield until input dims have been resolved.
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   if (!input_array.has_shape()) return;
   const Shape& input_shape = input_array.shape();
 
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) return;
 
   CHECK_EQ(input_shape.dims().size(), op->size.size());
@@ -500,7 +500,7 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) {
 
 void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
   const string& input_name = op->inputs[0];
-  const auto& input_array = *model->arrays[input_name];
+  const auto& input_array = model->GetArray(input_name);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -515,20 +515,20 @@ void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
 void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
   // Yield until input dims have been resolved.
   for (const auto& input_name : op->inputs) {
-    auto& input_array = *model->arrays[input_name];
+    auto& input_array = model->GetArray(input_name);
     if (!input_array.has_shape()) {
       return;
     }
   }
   auto& output_array = model->GetArray(op->outputs[0]);
   // Use 0 input as basis for output dimensions.
-  const auto& first_input_array = *model->arrays[op->inputs[0]];
+  const auto& first_input_array = model->GetArray(op->inputs[0]);
   output_array.copy_shape(first_input_array.shape());
   // Determine the concat size, and enfore that all inputs have
   // the same dimensions count.
   int concat_size = 0;
   for (const auto& input_name : op->inputs) {
-    auto& input_array = *model->arrays[input_name];
+    auto& input_array = model->GetArray(input_name);
     CHECK(input_array.has_shape());
     if (input_array.shape().dimensions_count() == 0) {
       continue;
@@ -548,16 +548,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
 
 void ProcessRangeOperator(Model* model, RangeOperator* op) {
   CHECK_EQ(op->inputs.size(), 3);
-  const auto& start_array = *model->arrays[op->inputs[0]];
+  const auto& start_array = model->GetArray(op->inputs[0]);
   if (!start_array.has_shape()) {
     // Yield until input dims have been resolved.
     return;
   }
-  const auto& limit_array = *model->arrays[op->inputs[1]];
+  const auto& limit_array = model->GetArray(op->inputs[1]);
   if (!limit_array.has_shape()) {
     return;
   }
-  const auto& delta_array = *model->arrays[op->inputs[2]];
+  const auto& delta_array = model->GetArray(op->inputs[2]);
   if (!delta_array.has_shape()) {
     return;
   }
@@ -599,7 +599,7 @@ void ProcessRangeOperator(Model* model, RangeOperator* op) {
 void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
   CHECK_EQ(op->inputs.size(), 2);
   const string& input_name = op->inputs[1];
-  const auto& input_array = *model->arrays[input_name];
+  const auto& input_array = model->GetArray(input_name);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -618,13 +618,13 @@ void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
 
   CHECK_EQ(op->outputs.size(), op->num_split);
   for (const auto& output : op->outputs) {
-    model->arrays[output]->copy_shape(output_shape);
+    model->GetArray(output).copy_shape(output_shape);
   }
 }
 
 void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
   const string& input_name = op->inputs[0];
-  const auto& input_array = *model->arrays[input_name];
+  const auto& input_array = model->GetArray(input_name);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -641,7 +641,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
 
 void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
   const string& input_name = op->inputs[0];
-  const auto& input_array = *model->arrays[input_name];
+  const auto& input_array = model->GetArray(input_name);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -658,7 +658,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
 
 void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
   const string& input_name = op->inputs[0];
-  const auto& input_array = *model->arrays[input_name];
+  const auto& input_array = model->GetArray(input_name);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -679,14 +679,14 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
   CHECK_EQ(op->inputs.size(), 2);
   CHECK_EQ(op->outputs.size(), 1);
 
-  if (!model->arrays[op->inputs[0]]->has_shape() ||
-      !model->arrays[op->inputs[1]]->has_shape()) {
+  if (!model->GetArray(op->inputs[0]).has_shape() ||
+      !model->GetArray(op->inputs[1]).has_shape()) {
     return;
   }
-  const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();
+  const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
 
   const string& output_size_name = op->inputs[1];
-  const auto& output_size_array = *model->arrays[output_size_name];
+  const auto& output_size_array = model->GetArray(output_size_name);
   CHECK(output_size_array.data_type == ArrayDataType::kInt32);
   CHECK(output_size_array.has_shape());
   const auto& output_size_shape = output_size_array.shape();
@@ -697,9 +697,9 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
   }
   std::vector<int32> output_shape =
       output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
-  model->arrays[op->outputs[0]]->copy_shape(
-      Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
-             input_data_shape.dims(3)}));
+  model->GetArray(op->outputs[0])
+      .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
+                         output_shape[1], input_data_shape.dims(3)}));
 }
 
 void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
@@ -708,7 +708,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
   QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
 
   const auto& input_array =
-      *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
+      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
   // Yield until all input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -717,7 +717,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
   CHECK_GE(input_shape.dimensions_count(), 2);
 
   const auto& prev_activ_array =
-      *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
+      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
   // Yield until all input dims have been resolved.
   if (!prev_activ_array.has_shape()) {
     return;
@@ -726,7 +726,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
   CHECK_GE(prev_activ_shape.dimensions_count(), 2);
 
   const auto& weights_array =
-      *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
+      model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
   // Yield until weights dims have been resolved.
   if (!weights_array.has_shape()) {
     return;
@@ -735,7 +735,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
   CHECK_EQ(weights_shape.dimensions_count(), 2);
 
   const auto& bias_array =
-      *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
+      model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
   // Yield until bias dims have been resolved.
   if (!bias_array.has_shape()) {
     return;
@@ -744,7 +744,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
   CHECK_GE(bias_shape.dimensions_count(), 1);
 
   const auto& prev_state_array =
-      *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
+      model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
   // Yield until all input dims have been resolved.
   if (!prev_state_array.has_shape()) {
     return;
@@ -784,7 +784,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
 }
 
 void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -797,8 +797,8 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
   const auto input_height = input_shape.dims(1);
   const auto input_width = input_shape.dims(2);
 
-  const auto& block_shape_array = *model->arrays[op->inputs[1]];
-  const auto& paddings_array = *model->arrays[op->inputs[2]];
+  const auto& block_shape_array = model->GetArray(op->inputs[1]);
+  const auto& paddings_array = model->GetArray(op->inputs[2]);
   const auto& block_shape_array_shape = block_shape_array.shape();
   const auto& paddings_array_shape = paddings_array.shape();
   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
@@ -830,13 +830,13 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
   int output_height = height_with_paddings / block_height;
   int output_width = width_with_paddings / block_width;
 
-  model->arrays[op->outputs[0]]->copy_shape(
-      Shape({input_shape.dims(0) * block_height * block_width, output_height,
-             output_width, input_shape.dims(3)}));
+  model->GetArray(op->outputs[0])
+      .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
+                         output_height, output_width, input_shape.dims(3)}));
 }
 
 void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -846,8 +846,8 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
   const auto input_height = input_shape.dims(1);
   const auto input_width = input_shape.dims(2);
 
-  const auto& block_shape_array = *model->arrays[op->inputs[1]];
-  const auto& crops_array = *model->arrays[op->inputs[2]];
+  const auto& block_shape_array = model->GetArray(op->inputs[1]);
+  const auto& crops_array = model->GetArray(op->inputs[2]);
   const auto& block_shape_array_shape = block_shape_array.shape();
   const auto& crops_array_shape = crops_array.shape();
   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
@@ -882,15 +882,15 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
   int output_height = input_height * block_height;
   int output_width = input_width * block_width;
 
-  model->arrays[op->outputs[0]]->copy_shape(
-      Shape({input_shape.dims(0) / (block_height * block_width), output_height,
-             output_width, input_shape.dims(3)}));
+  model->GetArray(op->outputs[0])
+      .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
+                         output_height, output_width, input_shape.dims(3)}));
 }
 
 void ProcessGatherOperator(Model* model, GatherOperator* op) {
-  const auto& input_array = *model->arrays[op->inputs[0]];
-  const auto& indices_array = *model->arrays[op->inputs[1]];
-  auto& output_array = *model->arrays[op->outputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
+  const auto& indices_array = model->GetArray(op->inputs[1]);
+  auto& output_array = model->GetArray(op->outputs[0]);
 
   // Bail if we already know the output shape.
   if (output_array.has_shape()) {
@@ -924,7 +924,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
   CHECK_EQ(op->inputs.size(), 2);
   CHECK_EQ(op->outputs.size(), 1);
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
 
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) return;
@@ -932,7 +932,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
   if (op->left_padding.empty()) return;
   CHECK_EQ(op->left_padding.size(), op->right_padding.size());
 
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) return;
 
   Shape output_shape = input_array.shape();
@@ -949,13 +949,13 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
 void ProcessRankOperator(Model* model, RankOperator* op) {
   CHECK_GE(op->inputs.size(), 1);
   CHECK_EQ(op->outputs.size(), 1);
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     // Shape already propagated
     return;
   }
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
     return;
@@ -970,13 +970,13 @@ void ProcessRankOperator(Model* model, RankOperator* op) {
 void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
   CHECK_GE(op->inputs.size(), 1);
   CHECK_EQ(op->outputs.size(), 1);
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     // Shape already propagated
     return;
   }
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
     return;
@@ -991,7 +991,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
 void ProcessStackOperator(Model* model, StackOperator* op) {
   CHECK_GE(op->inputs.size(), 1);
   CHECK_EQ(op->outputs.size(), 1);
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     // Shape already propagated
     return;
@@ -1032,7 +1032,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
   CHECK_GE(op->inputs.size(), 1);
   CHECK_EQ(op->outputs.size(), 1);
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     // Shape already propagated
     return;
@@ -1112,12 +1112,12 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
   CHECK_EQ(op->inputs.size(), 1);
   CHECK_EQ(op->outputs.size(), 1);
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
 
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) return;
 
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) return;
 
   const std::vector<int>& input_dims = input_array.shape().dims();
@@ -1136,18 +1136,18 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
 
 void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
   CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   if (!input_array.has_shape()) return;
 
-  auto& weights_feature_array = *model->arrays[op->inputs[1]];
+  auto& weights_feature_array = model->GetArray(op->inputs[1]);
   if (!weights_feature_array.has_shape()) return;
 
-  const auto& weights_time_array = *model->arrays[op->inputs[2]];
+  const auto& weights_time_array = model->GetArray(op->inputs[2]);
   if (!weights_time_array.has_shape()) return;
 
   const bool has_bias = (op->inputs.size() == 4);
   if (has_bias) {
-    const auto& bias_array = *model->arrays[op->inputs[3]];
+    const auto& bias_array = model->GetArray(op->inputs[3]);
     if (!bias_array.has_shape()) return;
   }
 
@@ -1164,13 +1164,13 @@ void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
 }
 
 void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
     // We have already run
     return;
   }
 
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
     return;
@@ -1204,7 +1204,7 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
 
 void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
   CHECK_EQ(op->inputs.size(), 2);
-  const auto& input_array = *model->arrays[op->inputs[0]];
+  const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
   if (!input_array.has_shape()) {
     return;
@@ -1222,7 +1222,7 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
   }
   output_dims.push_back(1);
   const string& output_name = op->outputs[0];
-  auto& output_array = *model->arrays[output_name];
+  auto& output_array = model->GetArray(output_name);
   if (output_array.has_shape()) {
     return;
   }
@@ -1236,8 +1236,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
   auto* op = it->get();
   std::unordered_map<string, std::vector<int>> old_output_dims;
   for (const auto& output : op->outputs) {
-    if (model->arrays[output]->has_shape()) {
-      old_output_dims[output] = model->arrays[output]->shape().dims();
+    if (model->GetArray(output).has_shape()) {
+      old_output_dims[output] = model->GetArray(output).shape().dims();
     }
   }
 
@@ -1433,10 +1433,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
   // Return true if any output dim changed, false if none changed.
   // Assumption: no transformation clears an output shape, they only add shapes.
   for (const auto& output : op->outputs) {
-    if (model->arrays[output]->has_shape() &&
-        (old_output_dims[output] != model->arrays[output]->shape().dims())) {
+    if (model->GetArray(output).has_shape() &&
+        (old_output_dims[output] != model->GetArray(output).shape().dims())) {
       AddMessageF("Set shape of %s to [%s]", output,
-                  absl::StrJoin(model->arrays[output]->shape().dims(), ","));
+                  absl::StrJoin(model->GetArray(output).shape().dims(), ","));
       return true;
     }
   }
index 56082b965a7cbd9d61cca2e26f7d76764c0e54aa..b973b2b813147cc580d2e87cea7d395f180f5aa1 100644 (file)
@@ -412,7 +412,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
               model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
             }
           }
-          model->arrays.erase(dequantize_op->outputs[0]);
+          model->EraseArray(dequantize_op->outputs[0]);
           model->operators.erase(dequantize_it);
         }
       }
index 371ced388a8111c18ada32cf31a784809479291d..11f8d4b6eea836c5fe4efcbd5136e6183a59dc62 100644 (file)
@@ -80,7 +80,7 @@ bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
     // else.
     for (int i = 1; i <= 2; i++) {
       if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
-        model->arrays.erase(fakequant_op->inputs[i]);
+        model->EraseArray(fakequant_op->inputs[i]);
       }
     }
     fakequant_op->inputs.resize(1);
index 3992e7d1ef71edd4040e626d5848d2fd9bb3dab6..c3b2709a33d54213661ba96394b01aa2cfd1a278 100644 (file)
@@ -51,7 +51,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
 
   // Remove the node and its output array.
   AddMessageF("Removed final %s", LogName(*dequantize_op));
-  model->arrays.erase(output);
+  model->EraseArray(output);
   model->operators.erase(dequantize_it);
   return true;
 }
index 6add443f2d62fd06e8c0d17e03bc78c5d74732a1..8512e6bb5ada41766a0ab6a4c06de060b898b1b4 100644 (file)
@@ -81,7 +81,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
   // Now check if the constant operand makes this binary
   // operator trivial.
   const auto& constant_input_array =
-      *model->arrays[binary_op->inputs[index_of_constant_input]];
+      model->GetArray(binary_op->inputs[index_of_constant_input]);
   // For now, we only handle floats here.
   if (constant_input_array.data_type != ArrayDataType::kFloat) {
     return false;
index 23a5c857e8b19f7edbb48f2c004d03e21008833d..936854a04fd600ea23ab5dda50370f85a311c28c 100644 (file)
@@ -59,7 +59,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
   for (const string& input : trivial_inputs) {
     if (IsDiscardableArray(*model, input) &&
         CountOpsWithInput(*model, input) == 1) {
-      model->arrays.erase(input);
+      model->EraseArray(input);
     }
   }
   concat_op->inputs = nontrivial_inputs;
index 047389f69a1d8987b52b07478b0d3eaf46f433ba..587f171bbf823408a45083c36d52f1d38c300123 100644 (file)
@@ -124,7 +124,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
       }
     }
     if (!is_referenced) {
-      model->arrays.erase(removal_candidate);
+      model->EraseArray(removal_candidate);
     }
   }
 
index e6cca8acf36745d989fb731aa948f257375d7e90..aa2c293382a98b476bee783ed8e177b19d35b858 100644 (file)
@@ -33,7 +33,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
   // the model. We allow specifying an arbitrary input_array,
   // treating the part of the graph leading up to it as unused.
   for (const auto& output : op->outputs) {
-    CHECK(model->arrays.count(output));
+    CHECK(model->HasArray(output));
     // If this output is provided as the model's input array,
     // then we don't need this operator to produce its contents.
     if (IsInputArray(*model, output)) {
@@ -93,7 +93,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
     if (IsDiscardableArray(*model, input) &&
         CountOpsWithInput(*model, input) == 1 &&
         !GetOpWithOutput(*model, input)) {
-      model->arrays.erase(input);
+      model->EraseArray(input);
     }
   }
 
@@ -116,7 +116,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
       continue;
     }
     // Generic case: do delete this output array.
-    model->arrays.erase(output);
+    model->EraseArray(output);
   }
   model->operators.erase(it);
   return true;
index 3eb7fa3896c57ea612f21f8b4f3fa568d19420d4..fb109eb91b16e3a73005230f821c18b9ef82d2fb 100644 (file)
@@ -121,9 +121,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
   }
 
   // Remove the old param arrays
-  model->arrays.erase(bn_op->inputs[1]);
-  model->arrays.erase(bn_op->inputs[2]);
-  model->arrays.erase(bn_op->inputs[3]);
+  model->EraseArray(bn_op->inputs[1]);
+  model->EraseArray(bn_op->inputs[2]);
+  model->EraseArray(bn_op->inputs[3]);
 
   // Remove the old operator
   DCHECK_EQ(bn_it->get(), bn_op);
index 7777d4f54359071c775806999ecf1418a8762d60..a06919e228dc2084f8943a714a0ca111d013c159 100644 (file)
@@ -42,7 +42,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
     return false;
 
   // Handle crops
-  const auto& crops_array = *model->arrays[op->inputs[2]];
+  const auto& crops_array = model->GetArray(op->inputs[2]);
   if (!crops_array.has_shape()) return false;
   const std::vector<int>& crops_dims = crops_array.shape().dims();
   if (crops_dims.size() != 2) {
@@ -58,7 +58,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
   }
 
   // Handle block_shape
-  const auto& block_shape_array = *model->arrays[op->inputs[1]];
+  const auto& block_shape_array = model->GetArray(op->inputs[1]);
   if (!block_shape_array.has_shape()) return false;
   const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
   CHECK_EQ(block_shape_dims.size(), 1);
index fd51df4058dbda4732686983f9b9dab3781ec4d1..5e779f6765262326bd59db886c2feed603e0102e 100644 (file)
@@ -166,8 +166,9 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
 
 void EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                             const Operator* binary_op) {
-  const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
-  const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
+  const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
+  const auto output_data_type =
+      model->GetArray(binary_op->outputs[0]).data_type;
 #define TOCO_HANDLE_CASE(InputsDataType, OutputDataType)                    \
   if (inputs_data_type == InputsDataType &&                                 \
       output_data_type == OutputDataType) {                                 \
@@ -214,7 +215,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
     return false;
   }
 
-  auto& output_array = *model->arrays[binary_op->outputs[0]];
+  auto& output_array = model->GetArray(binary_op->outputs[0]);
   // Yield until the output array dims have been resolved.
   if (!output_array.has_shape()) {
     return false;
@@ -239,10 +240,10 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
 
   // Remove the binary operator and its inputs
   if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
-    model->arrays.erase(binary_op->inputs[0]);
+    model->EraseArray(binary_op->inputs[0]);
   }
   if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
-    model->arrays.erase(binary_op->inputs[1]);
+    model->EraseArray(binary_op->inputs[1]);
   }
   AddMessageF("Resolved constant %s to the equivalent constant array",
               LogName(*binary_op));
index 9835f86398a37f118d3ebd5b568ffddbcd56c38b..833c97c758f7587a3f5bf799e2de6417153454f2 100644 (file)
@@ -189,7 +189,7 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
 
   // Remove all the resolved arrays.
   for (const string& input_name : concat_op->inputs) {
-    model->arrays.erase(input_name);
+    model->EraseArray(input_name);
   }
 
   // Remove concatenate operator
index 244adcc4c46eda9de79dd753565113bbeca970c5..81fe37d7e017c6e2440de34cc2daedf7fb2a422e 100644 (file)
@@ -66,7 +66,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
     output_buffer.data[i] = dst_val;
   }
   if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
-    model->arrays.erase(fakequant_op->inputs[0]);
+    model->EraseArray(fakequant_op->inputs[0]);
   }
   model->operators.erase(fakequant_it);
 
index 9da51d9147a98a935d00db04827aa7ebb12998b9..f6f95481b57f58f497b119df73d331f13d9705c0 100644 (file)
@@ -104,11 +104,11 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
   // Erase input arrays if no longer used
   if (IsDiscardableArray(*model, op->inputs[0]) &&
       CountOpsWithInput(*model, op->inputs[0]) == 1) {
-    model->arrays.erase(op->inputs[0]);
+    model->EraseArray(op->inputs[0]);
   }
   if (IsDiscardableArray(*model, op->inputs[1]) &&
       CountOpsWithInput(*model, op->inputs[1]) == 1) {
-    model->arrays.erase(op->inputs[1]);
+    model->EraseArray(op->inputs[1]);
   }
 
   // Erase the operator
index 383d54aa5a7fa4933a9eb9ffac014bab4497d40d..1a0ba9e2bc7235720b59210cdd6affa089613077 100644 (file)
@@ -28,17 +28,17 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
   auto* op = static_cast<RangeOperator*>(base_op);
 
   CHECK_EQ(op->inputs.size(), 3);
-  const auto& start_array = *model->arrays[op->inputs[0]];
+  const auto& start_array = model->GetArray(op->inputs[0]);
   if (!start_array.has_shape()) {
     // Yield until all input dims have been resolved.
     return false;
   }
-  const auto& limit_array = *model->arrays[op->inputs[1]];
+  const auto& limit_array = model->GetArray(op->inputs[1]);
   if (!limit_array.has_shape()) {
     // Yield until all input dims have been resolved.
     return false;
   }
-  const auto& delta_array = *model->arrays[op->inputs[2]];
+  const auto& delta_array = model->GetArray(op->inputs[2]);
   if (!delta_array.has_shape()) {
     // Yield until all input dims have been resolved.
     return false;
@@ -52,7 +52,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
   }
 
   CHECK_EQ(op->outputs.size(), 1);
-  auto& output_array = *model->arrays[op->outputs[0]];
+  auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.data_type == ArrayDataType::kNone) {
     // Yield until the output type has been set by PropagateArrayDataTypes
     return false;
@@ -87,15 +87,15 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
   // Delete the input array if no longer used
   if (IsDiscardableArray(*model, op->inputs[0]) &&
       CountOpsWithInput(*model, op->inputs[0]) == 1) {
-    model->arrays.erase(op->inputs[0]);
+    model->EraseArray(op->inputs[0]);
   }
   if (IsDiscardableArray(*model, op->inputs[1]) &&
       CountOpsWithInput(*model, op->inputs[1]) == 1) {
-    model->arrays.erase(op->inputs[1]);
+    model->EraseArray(op->inputs[1]);
   }
   if (IsDiscardableArray(*model, op->inputs[2]) &&
       CountOpsWithInput(*model, op->inputs[2]) == 1) {
-    model->arrays.erase(op->inputs[2]);
+    model->EraseArray(op->inputs[2]);
   }
 
   // Delete the operator
index 35b81dd5506cfb0048ab1347bfefd07b128bc92b..9ea01acd05364224ce219bed533c999793a2a2f1 100644 (file)
@@ -62,7 +62,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
   // Delete the input array if no longer used
   if (IsDiscardableArray(*model, op->inputs[0]) &&
       CountOpsWithInput(*model, op->inputs[0]) == 1) {
-    model->arrays.erase(op->inputs[0]);
+    model->EraseArray(op->inputs[0]);
   }
 
   model->operators.erase(it);
index 86c76141a4705de841c8e70790cce7be28fb59c9..ea0d6dc8200897db9266efbe41556dbf4c296db3 100644 (file)
@@ -101,7 +101,7 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
   for (const auto& input : op->inputs) {
     if (IsDiscardableArray(*model, input) &&
         CountOpsWithInput(*model, input) == 1) {
-      model->arrays.erase(input);
+      model->EraseArray(input);
     }
   }
 
index 3976d9cbb492138c0c45801045833e08411acbd4..a0cfc3d59763dc1211ed4d1ac114d371a4a7ee0b 100644 (file)
@@ -186,7 +186,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
   // Erase input array if no longer used
   if (IsDiscardableArray(*model, op->inputs[0]) &&
       CountOpsWithInput(*model, op->inputs[0]) == 1) {
-    model->arrays.erase(op->inputs[0]);
+    model->EraseArray(op->inputs[0]);
   }
 
   // Erase the operator
index 26ff9d887b40651559ad030cd41a824679d6dd15..1cd2aff28c68eaba4e9b18d8e2c2803834328696 100644 (file)
@@ -199,7 +199,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
   }
   for (const auto& input : unary_op->inputs) {
     if (CountOpsWithInput(*model, input) == 1) {
-      model->arrays.erase(input);
+      model->EraseArray(input);
     }
   }
   AddMessageF("Resolved constant %s to the equivalent constant array",
index b77be3f5c0d04b028391c1ce9de39afd7632eb36..013b50ac9ba8a51c23b19953d987b2fbf63fcea1 100644 (file)
@@ -36,7 +36,7 @@ bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) {
   if (op->inputs.size() != 2) return false;
   if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
 
-  const auto& indices_array = *model->arrays[op->inputs[1]];
+  const auto& indices_array = model->GetArray(op->inputs[1]);
   if (!indices_array.has_shape()) return false;
   op->axis = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
   return true;
index d5f5869c625f419a825f6bd652a04eca1bce4a6f..8a8e723cf7b2d77ec199e3817464a068bf85afdd 100644 (file)
@@ -35,7 +35,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(op->inputs.size(), 2);
   if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
 
-  const auto& array = *model->arrays[op->inputs[1]];
+  const auto& array = model->GetArray(op->inputs[1]);
   if (!array.has_shape()) return false;
 
   const std::vector<int>& dims = array.shape().dims();
index b5093bc4c7c33b3e555ca14151c2489cddc6dbd3..5c68f87f6ccd912a94213c95a59a78076b0e768b 100644 (file)
@@ -103,7 +103,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
   AddMessageF("Reordered axes for array %s", input_array_name);
 
   // Remove the op and output array.
-  model->arrays.erase(output_array_name);
+  model->EraseArray(output_array_name);
   model->operators.erase(reorder_it);
   return true;
 }
index bed2a85bd262c49913f22e522d260c4dc6510246..2e063e35548aa5e51c3bcc94a2dfc7992180d014 100644 (file)
@@ -37,7 +37,7 @@ bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
   if (!op->shape.empty()) return false;
 
   if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
-    const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]];
+    const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]);
     op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
   }
 
index 1d0a2ec8f6c1f532f23873062534a37e07fff72b..e760d08e5a6c2f56db6b11fee922b701d33dd1a0 100644 (file)
@@ -36,10 +36,10 @@ bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
   if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
   if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
 
-  const auto& begin_array = *model->arrays[op->inputs[1]];
+  const auto& begin_array = model->GetArray(op->inputs[1]);
   if (!begin_array.has_shape()) return false;
 
-  const auto& size_array = *model->arrays[op->inputs[2]];
+  const auto& size_array = model->GetArray(op->inputs[2]);
   if (!size_array.has_shape()) return false;
 
   op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
index a73f16735cb232753e8f64caae31f5c945b6bffd..dad6aceccfd201b3db07c29c99a8c6ef75bb89a1 100644 (file)
@@ -45,7 +45,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
     return false;
 
   // Handle paddings.
-  const auto& paddings_array = *model->arrays[op->inputs[paddings_index]];
+  const auto& paddings_array = model->GetArray(op->inputs[paddings_index]);
   if (!paddings_array.has_shape()) return false;
   const std::vector<int>& paddings_dims = paddings_array.shape().dims();
   if (paddings_dims.size() != 2) {
@@ -61,7 +61,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
   }
 
   // Handle block_shape.
-  const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]];
+  const auto& block_shape_array =
+      model->GetArray(op->inputs[block_shape_index]);
   if (!block_shape_array.has_shape()) return false;
   const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
   CHECK_EQ(block_shape_dims.size(), 1);
index dbe69adcbd34bb0544239ebb096fb8bfc4bfcb49..de4d06be2a551b7ba6c99a72b1b11454f9cdcdcf 100644 (file)
@@ -31,13 +31,13 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
   }
 
   CHECK_EQ(op->inputs.size(), 4);
-  const auto& start_array = *model->arrays[op->inputs[1]];
+  const auto& start_array = model->GetArray(op->inputs[1]);
   if (!start_array.has_shape()) return false;
 
-  const auto& stop_array = *model->arrays[op->inputs[2]];
+  const auto& stop_array = model->GetArray(op->inputs[2]);
   if (!stop_array.has_shape()) return false;
 
-  const auto& stride_array = *model->arrays[op->inputs[3]];
+  const auto& stride_array = model->GetArray(op->inputs[3]);
   if (!stride_array.has_shape()) return false;
 
   if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
index c6723a880ed0e51cc5828f77742a6c8eb70fa864..5c0c1e3478fa0d94104d1b76bab176b98b314c50 100644 (file)
@@ -75,7 +75,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
 
   // Remove the axis array if it is not used by anything else.
   if (CountOpsWithInput(*model, axis_name) == 1) {
-    model->arrays.erase(axis_name);
+    model->EraseArray(axis_name);
   }
   // Remove the TensorFlowConcat op
   model->operators.erase(concat_it);
index bea7487051a58344a56a3186a05d0fdceebc8727..ad1e56888e53133c5a84cc0e3d5e76b7ef3b29b4 100644 (file)
@@ -69,7 +69,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
                 LogName(*matmul_op), LogName(*fc_op));
     const auto& previous_op_output = previous_op->outputs[0];
     if (CountOpsWithInput(*model, previous_op_output) == 1) {
-      model->arrays.erase(previous_op_output);
+      model->EraseArray(previous_op_output);
     }
     CHECK_EQ(previous_op->inputs.size(), 2);
     fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]};
@@ -78,7 +78,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
       const auto& previous_op_shape = previous_op->inputs[1];
       if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
           !GetOpWithOutput(*model, previous_op_shape)) {
-        model->arrays.erase(previous_op_shape);
+        model->EraseArray(previous_op_shape);
       }
       model->operators.erase(previous_op_it);
     }
index cfa5ce0716523adbfb0a76e89ce3b202f0595763..477e7f13da3d88a68547d494011cd4984936b909 100644 (file)
@@ -55,7 +55,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
 
   // Remove the node and its output array.
   AddMessageF("Removing already-resolved %s", LogName(*merge_op));
-  model->arrays.erase(merge_op->outputs[0]);
+  model->EraseArray(merge_op->outputs[0]);
   model->operators.erase(merge_it);
   return true;
 }
index 150cf53da3099227c5c637ee58c44512d5a41d4f..a418073441f1241a5acb1164b36f332828ea2e99 100644 (file)
@@ -103,7 +103,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
   // Remove the output arrays if they are now unused.
   for (int i = 0; i < 2; i++) {
     if (!GetOpWithInput(*model, switch_op->outputs[i])) {
-      model->arrays.erase(switch_op->outputs[i]);
+      model->EraseArray(switch_op->outputs[i]);
     }
   }
   // Remove input arrays if they are only used by the switch itself and aren't
@@ -111,7 +111,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
   for (const auto& input : switch_op->inputs) {
     if (CountOpsWithInput(*model, input) == 1 &&
         !GetOpWithOutput(*model, input)) {
-      model->arrays.erase(input);
+      model->EraseArray(input);
     }
   }
   // Remove the switch node itself.
index 9f7e7c42a26b60c96573be6653babb78fdb5fd73..1ddf54c778cd1fae7a8fce0ecb97209274e71ac0 100644 (file)
@@ -45,10 +45,10 @@ void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
   model->operators.erase(tile_it);
   if (!CountOpsWithInput(*model, tile_multiplier_array) &&
       !GetOpWithOutput(*model, tile_multiplier_array)) {
-    model->arrays.erase(tile_multiplier_array);
+    model->EraseArray(tile_multiplier_array);
   }
   if (!CountOpsWithInput(*model, tile_output_array)) {
-    model->arrays.erase(tile_output_array);
+    model->EraseArray(tile_output_array);
   }
 }
 }  // namespace
index 12d966b26104fd491f914fbdb39e0a62fdda19bc..a657ee00af66bd431f96c361e12d5213e203b3df 100644 (file)
@@ -35,7 +35,7 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
   if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
 
   // Handling perm.
-  const auto& perm_array = *model->arrays[op->inputs[1]];
+  const auto& perm_array = model->GetArray(op->inputs[1]);
   if (!perm_array.has_shape()) return false;
 
   const std::vector<int>& perm_dims = perm_array.shape().dims();
index a14016e8e2705a66c392118899335eb3997fa1de..3a1d175b9823f085c9b8730caba8bedd7eb87d52 100644 (file)
@@ -19,7 +19,6 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
-//#include "tensorflow/contrib/lite/kernels/test_util.h"
 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -168,11 +167,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
 
   GraphTransformationsSet graph_transformation_set;
   graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
-  EXPECT_THAT(model.arrays.size(), 5);
+  EXPECT_THAT(model.GetArrayMap().size(), 5);
   (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-  EXPECT_THAT(model.arrays.size(), 1);
+  EXPECT_THAT(model.GetArrayMap().size(), 1);
 
-  auto& concatenated_array = (*model.arrays.begin()).second;
+  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
   EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
               ElementsAreArray(ArrayFloatNear(
                   {0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  10., 11., 12.,
@@ -187,11 +186,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
 
   GraphTransformationsSet graph_transformation_set;
   graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
-  EXPECT_THAT(model.arrays.size(), 5);
+  EXPECT_THAT(model.GetArrayMap().size(), 5);
   (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-  EXPECT_THAT(model.arrays.size(), 1);
+  EXPECT_THAT(model.GetArrayMap().size(), 1);
 
-  auto& concatenated_array = (*model.arrays.begin()).second;
+  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
   EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
               ElementsAreArray(ArrayFloatNear(
                   {0.,  1.,  2.,  3.,  10., 11., 12., 13., 20., 21., 22.,
@@ -206,11 +205,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
 
   GraphTransformationsSet graph_transformation_set;
   graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
-  EXPECT_THAT(model.arrays.size(), 5);
+  EXPECT_THAT(model.GetArrayMap().size(), 5);
   (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-  EXPECT_THAT(model.arrays.size(), 1);
+  EXPECT_THAT(model.GetArrayMap().size(), 1);
 
-  auto& concatenated_array = (*model.arrays.begin()).second;
+  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
   EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
               ElementsAreArray(ArrayFloatNear(
                   {0.,  1.,  10., 11., 20., 21., 30., 31., 2.,  3.,  12.,
index 4e273343df9f3e5ade8f23a2fbd868bcab72c62e..2c7046c8c77c94a89fc05a26d7d72b3661380475 100644 (file)
@@ -63,7 +63,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
   ac_op->outputs = op->outputs;
   const string& tmp_array_name =
       AvailableArrayName(*model, op->outputs[0] + "_unfused");
-  CHECK(!model->arrays.count(tmp_array_name));
+  CHECK(!model->HasArray(tmp_array_name));
   model->GetOrCreateArray(tmp_array_name);
   ac_op->inputs = {tmp_array_name};
   op->outputs = {tmp_array_name};
index 995e9d67ca3ae34471595d2d629d2fe993c21ab5..1947271f55d66b74a04cdb9bb06b724c4b5f4133 100644 (file)
@@ -1652,7 +1652,7 @@ void StripCaretFromArrayNames(Model* model) {
       output = string(absl::StripPrefix(output, "^"));
     }
   }
-  for (auto& array : model->arrays) {
+  for (auto& array : model->GetArrayMap()) {
     if (absl::StartsWith(array.first, "^")) {
       LOG(FATAL) << "What?";
     }
index b079750bed69f4caad7ab194b91e5ae1f5ca282e..54fbba7381b4b8ad73dc5ce0baba5cde8fce90ef 100644 (file)
@@ -1521,15 +1521,19 @@ struct Array {
 
 // Our Model struct, represents an entire model (our "top-level" struct).
 // Owns everything.
-struct Model {
+class Model {
+ public:
+  using ArrayMap = std::unordered_map<string, std::unique_ptr<Array>>;
+
+  bool HasArray(const string& name) const { return arrays.count(name) > 0; }
   Array& GetArray(const string& name) const {
-    DCHECK(arrays.count(name));
+    DCHECK(HasArray(name));
     return *arrays.at(name);
   }
   Array& GetOrCreateArray(const string& name) {
     // Make sure name is not used by an optional array
     DCHECK(!optional_arrays.count(name));
-    if (!arrays.count(name)) {
+    if (!HasArray(name)) {
       Array* ptr = new Array;
       arrays[name] = std::unique_ptr<Array>(ptr);
     }
@@ -1544,18 +1548,27 @@ struct Model {
     return optional_arrays.count(name);
   }
 
+  // Note that this invalidates all array iterators.
+  void EraseArray(const string& name) { arrays.erase(name); }
+  void EraseArrays(std::function<bool(const string&)> discardable) {
+    for (auto it = arrays.begin(); it != arrays.end();) {
+      if (discardable(it->first)) {
+        it = arrays.erase(it);
+      } else {
+        ++it;
+      }
+    }
+  }
+  const ArrayMap& GetArrayMap() const { return arrays; }
+
   // Optional arrays are used for optional tensors,
   // these tensors do not have data, but with reserved names as op inputs.
   std::set<string> optional_arrays;
+
   // The list of operators. Notice how it's a list of unique_ptr's, implying
   // that the Model is what owns Operator's and keeps them alive.
   std::vector<std::unique_ptr<Operator>> operators;
-  // The associative array mapping names to Array's.
-  // Notice how it's a container of unique_ptr's, implying
-  // that the Model is what owns Array's and keeps them alive.
-  // The Operator's refer to these Array's by their name strings, not by their
-  // addresses. See Operator::inputs, Operator::outputs.
-  std::unordered_map<string, std::unique_ptr<Array>> arrays;
+
   // Generic flags, a place where we combine information passed to us via
   // command-line parameters (e.g. --input_width=N) with information that
   // we may or may not find in the input model file.
@@ -1564,6 +1577,14 @@ struct Model {
   std::size_t transient_data_size = 0;
   // For code-generation only: required alignment of the transient_data buffer
   std::size_t transient_data_alignment = 0;
+
+ private:
+  // The associative array mapping names to Array's.
+  // Notice how it's a container of unique_ptr's, implying
+  // that the Model is what owns Array's and keeps them alive.
+  // The Operator's refer to these Array's by their name strings, not by their
+  // addresses. See Operator::inputs, Operator::outputs.
+  std::unordered_map<string, std::unique_ptr<Array>> arrays;
 };
 }  // namespace toco
 
index 440353203e7aeb93d7ee3e2bb1971bc57843f933..391ef87029d019ab52af2716f72883f5f82f94d9 100644 (file)
@@ -62,7 +62,7 @@ namespace details {
 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
   // First find a list of unique array names.
   std::set<string> names;
-  for (const auto& array_pair : model.arrays) {
+  for (const auto& array_pair : model.GetArrayMap()) {
     names.insert(array_pair.first);
   }
 
@@ -96,7 +96,7 @@ Offset<Vector<Offset<Tensor>>> ExportTensors(
   // tensors in the tensors_map.
   std::map<int, Offset<Tensor>> ordered_tensors;
 
-  for (const auto& array_pair : model.arrays) {
+  for (const auto& array_pair : model.GetArrayMap()) {
     const string& tensor_name = array_pair.first;
     const toco::Array& array = *array_pair.second;
 
index 309fa6d7f688ba1dd99a7e6eeda14d513a9e49d4..aad6e780d5eb5c3dbc880906df5053ad231ffd54 100644 (file)
@@ -114,7 +114,7 @@ TEST_F(ImportTest, Tensors) {
 
   auto model = Import(ModelFlags(), InputModelAsString());
 
-  ASSERT_GT(model->arrays.count("tensor_one"), 0);
+  ASSERT_GT(model->HasArray("tensor_one"), 0);
   Array& a1 = model->GetArray("tensor_one");
   EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
   EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
index 94b4d146968d4bf92bd8f662763eecdc92a66663..afaa0fd0c74a353d3ec74385efdca4923189093e 100644 (file)
@@ -133,7 +133,7 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
 
   for (int i = 0; i < model->flags.input_arrays_size(); i++) {
     string const& array_name = model->flags.input_arrays(i).name();
-    auto* array = model->arrays[array_name].get();
+    auto* array = &model->GetArray(array_name);
     // Note that the notion of changing data types only applies to real-numbers
     // arrays (see the documentation for inference_input_type).
     // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
index f9093ab973b3eb90153ef87212f74a346e18d6a5..900c60bd909824996563c2a769efa03a51d9c065 100644 (file)
@@ -93,7 +93,7 @@ int CountOpsWithInput(const Model& model, const string& array_name) {
 
 bool DeleteArrayIfUnused(const string& array_name, Model* model) {
   if (CountOpsWithInput(*model, array_name) == 0) {
-    model->arrays.erase(array_name);
+    model->EraseArray(array_name);
     return true;
   }
   return false;
@@ -566,11 +566,11 @@ int RequiredBufferSizeForShape(const Shape& shape) {
 }
 
 bool IsConstantParameterArray(const Model& model, const string& name) {
-  if (!model.arrays.count(name)) {
+  if (!model.HasArray(name)) {
     return false;
   }
 
-  return !!model.arrays.at(name)->buffer;
+  return !!model.GetArray(name).buffer;
 }
 
 namespace {
@@ -633,17 +633,17 @@ void CheckNonExistentIOArrays(const Model& model) {
     return;
   }
   for (const auto& input_array : model.flags.input_arrays()) {
-    CHECK(model.arrays.count(input_array.name()))
+    CHECK(model.HasArray(input_array.name()))
         << "Input array not found: " << input_array.name();
   }
   for (const string& output_array : model.flags.output_arrays()) {
-    CHECK(model.arrays.count(output_array))
+    CHECK(model.HasArray(output_array))
         << "Output array not found: " << output_array;
   }
   for (const auto& rnn_state : model.flags.rnn_states()) {
     if (!rnn_state.discardable()) {
-      CHECK(model.arrays.count(rnn_state.state_array()));
-      CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+      CHECK(model.HasArray(rnn_state.state_array()));
+      CHECK(model.HasArray(rnn_state.back_edge_source_array()));
     }
   }
 }
@@ -652,10 +652,10 @@ void CheckNonExistentIOArrays(const Model& model) {
 void CheckNoMissingArray(const Model& model) {
   for (const auto& op : model.operators) {
     for (const auto& input : op->inputs) {
-      CHECK(model.arrays.count(input) || model.optional_arrays.count(input));
+      CHECK(model.HasArray(input) || model.optional_arrays.count(input));
     }
     for (const auto& output : op->outputs) {
-      CHECK(model.arrays.count(output));
+      CHECK(model.HasArray(output));
     }
   }
   CheckNonExistentIOArrays(model);
@@ -664,12 +664,12 @@ void CheckNoMissingArray(const Model& model) {
 void FixNoMissingArray(Model* model) {
   for (const auto& op : model->operators) {
     for (const auto& input : op->inputs) {
-      if (!model->arrays.count(input)) {
+      if (!model->HasArray(input)) {
         model->GetOrCreateArray(input);
       }
     }
     for (const auto& output : op->outputs) {
-      if (!model->arrays.count(output)) {
+      if (!model->HasArray(output)) {
         model->GetOrCreateArray(output);
       }
     }
@@ -687,7 +687,7 @@ void FixNoMissingArray(Model* model) {
 
 void CheckNoOrphanedArray(const Model& model) {
   std::unordered_set<string> arrays_without_known_use;
-  for (const auto& array : model.arrays) {
+  for (const auto& array : model.GetArrayMap()) {
     if (IsDiscardableArray(model, array.first)) {
       arrays_without_known_use.insert(array.first);
     }
@@ -714,7 +714,7 @@ void CheckNoOrphanedArray(const Model& model) {
 
 void FixNoOrphanedArray(Model* model) {
   std::unordered_set<string> arrays_without_known_use;
-  for (const auto& array : model->arrays) {
+  for (const auto& array : model->GetArrayMap()) {
     arrays_without_known_use.insert(array.first);
   }
   for (const auto& op : model->operators) {
@@ -731,13 +731,13 @@ void FixNoOrphanedArray(Model* model) {
   }
   for (const auto& array : arrays_without_known_use) {
     if (IsDiscardableArray(*model, array)) {
-      model->arrays.erase(array);
+      model->EraseArray(array);
     }
   }
 }
 
 void CheckArrayFieldsConsistent(const Model& model) {
-  for (const auto& array_entry : model.arrays) {
+  for (const auto& array_entry : model.GetArrayMap()) {
     const auto& array = array_entry.second;
     if (array->has_shape()) {
       for (int d : array->shape().dims()) {
@@ -756,7 +756,7 @@ void CheckArrayFieldsConsistent(const Model& model) {
 
 void CheckOperatorOrdering(const Model& model) {
   std::unordered_set<string> arrays_behind_us;
-  for (const auto& array_entry : model.arrays) {
+  for (const auto& array_entry : model.GetArrayMap()) {
     if (!GetOpWithOutput(model, array_entry.first)) {
       arrays_behind_us.insert(array_entry.first);
     }
@@ -781,7 +781,7 @@ void CheckOperatorOrdering(const Model& model) {
 
 void FixOperatorOrdering(Model* model) {
   std::unordered_set<string> arrays_behind_us;
-  for (const auto& array_entry : model->arrays) {
+  for (const auto& array_entry : model->GetArrayMap()) {
     if (!GetOpWithOutput(*model, array_entry.first)) {
       arrays_behind_us.insert(array_entry.first);
     }
@@ -936,7 +936,8 @@ void CheckModelCounts(const Model& model) {
     if (count_type == "None") {
       continue;
     } else if (count_type == "Arrays") {
-      CheckCountInRange(model_check, model.arrays.size(), "count of arrays");
+      CheckCountInRange(model_check, model.GetArrayMap().size(),
+                        "count of arrays");
     } else if (count_type == "Total") {
       CheckCountInRange(model_check, model.operators.size(),
                         "count of all operator instances");
@@ -1297,7 +1298,7 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
       return false;
     }
   }
-  const auto& array = model.arrays.at(array_name);
+  const auto& array = &model.GetArray(array_name);
   // An array with a constant buffer isn't a transient array.
   if (!!array->buffer) {
     return false;
@@ -1310,13 +1311,13 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
 }
 
 string AvailableArrayName(const Model& model, const string& name) {
-  if (!model.arrays.count(name) && !model.optional_arrays.count(name)) {
+  if (!model.HasArray(name) && !model.optional_arrays.count(name)) {
     return name;
   }
   const int kNumSuffixesToTry = 1000;
   for (int i = 0; i < kNumSuffixesToTry; i++) {
     const string& name_with_suffix = toco::port::StringF("%s_%d", name, i);
-    if (!model.arrays.count(name_with_suffix)) {
+    if (!model.HasArray(name_with_suffix)) {
       return name_with_suffix;
     }
   }
@@ -1334,12 +1335,12 @@ string ShapeToString(const Shape& shape) {
 }
 
 void PrintArrayShape(Model* model, const string& name) {
-  if (!model->arrays[name]->has_shape()) {
+  if (!model->GetArray(name).has_shape()) {
     LOG(INFO) << name << " has no shape";
     return;
   }
   LOG(INFO) << name
-            << " has shape: " << ShapeToString(model->arrays[name]->shape());
+            << " has shape: " << ShapeToString(model->GetArray(name).shape());
 }
 
 bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
@@ -1673,7 +1674,7 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
 }
 
 void CheckFinalDataTypesSatisfied(const Model& model) {
-  for (const auto& array_entry : model.arrays) {
+  for (const auto& array_entry : model.GetArrayMap()) {
     const auto& array = *array_entry.second;
     if (array.final_data_type != ArrayDataType::kNone) {
       CHECK(array.final_data_type == array.data_type)