if (!IsAllocatableTransientArray(model, array_name)) {
return 0;
}
- const auto& array = model.arrays.at(array_name);
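+ // GetArray() returns an Array&; taking its address keeps the pointer-style
+ // "array->" accesses below unchanged.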
+ auto* array = &model.GetArray(array_name);
CHECK(array->has_shape())
<< "Array '" << array_name << "' doesn't have a shape";
if (array->data_type == ArrayDataType::kNone) {
}
const std::size_t size =
TransientArraySize(model, array_name, transient_data_alignment);
- const auto& array = model.arrays.at(array_name);
+ auto* array = &model.GetArray(array_name);
CHECK(!array->alloc);
allocator->Allocate(size, &array->GetOrCreateAlloc());
}
if (!IsAllocatableTransientArray(model, array_name)) {
return;
}
- const auto& array = model.arrays.at(array_name);
+ auto* array = &model.GetArray(array_name);
CHECK(!!array->alloc);
allocator->Deallocate(*array->alloc);
}
// Construct a sorted map of array names, so that other layout engines can
// match exactly.
std::map<string, const Array*> ordered_arrays_map;
- for (const auto& pair : model->arrays) {
+ for (const auto& pair : model->GetArrayMap()) {
ordered_arrays_map[pair.first] = pair.second.get();
}
if (last_specified) {
// Return only the part of the graph between graphviz_first_array
// and graphviz_last_array.
- CHECK(model.arrays.count(dump_options.graphviz_first_array));
- CHECK(model.arrays.count(dump_options.graphviz_last_array));
+ CHECK(model.HasArray(dump_options.graphviz_first_array));
+ CHECK(model.HasArray(dump_options.graphviz_last_array));
std::unordered_set<string> arrays_already_produced;
std::vector<string> arrays_to_produce;
arrays_to_produce.push_back(dump_options.graphviz_last_array);
op_properties.color.TextColorString().c_str());
// Add nodes and edges for all inputs of the operator.
for (const auto& input : op.inputs) {
- if (model.arrays.count(input) == 0) {
+ if (!model.HasArray(input)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
}
// Add nodes and edges for all outputs of the operator.
for (const auto& output : op.outputs) {
- if (model.arrays.count(output) == 0) {
+ if (!model.HasArray(output)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
- CHECK(model.arrays.count(name));
- const auto& input_array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& input_array = model.GetArray(name);
const auto& input_shape = input_array.shape();
CHECK(input_array.buffer);
CHECK(input_array.buffer->type == ArrayDataType::kFloat);
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
- CHECK(model.arrays.count(name));
- const auto& input_array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& input_array = model.GetArray(name);
const auto& input_shape = input_array.shape();
CHECK(input_array.buffer);
CHECK(input_array.buffer->type == ArrayDataType::kFloat);
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- CHECK(model.arrays.count(name));
- const auto& array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& array = model.GetArray(name);
auto* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
biasadd_op->add_input(conv_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
+ CHECK(model.HasArray(src_op.inputs[2]));
const string& bias_array_name =
WalkUpToConstantArray(model, src_op.inputs[2]);
const auto& bias_array = model.GetArray(bias_array_name);
// We need to convert that to H x W x InputDepth x Multiplier.
// That's only a matter of constructing a Dims object; the actual
// array layout is the same.
- CHECK(model.arrays.count(src_op.inputs[1]));
+ CHECK(model.HasArray(src_op.inputs[1]));
const string& src_weights_name =
WalkUpToConstantArray(model, src_op.inputs[1]);
const auto& src_weights_array = model.GetArray(src_weights_name);
biasadd_op->add_input(conv_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
+ CHECK(model.HasArray(src_op.inputs[2]));
const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
const auto& bias_array = model.GetArray(bias_name);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
(*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
(*matmul_op->mutable_attr())["transpose_a"].set_b(false);
(*matmul_op->mutable_attr())["transpose_b"].set_b(false);
- CHECK(model.arrays.count(src_op.inputs[1]));
+ CHECK(model.HasArray(src_op.inputs[1]));
const string& fc_weights_name =
WalkUpToConstantArray(model, src_op.inputs[1]);
- const auto& fc_weights_array = *model.arrays.at(fc_weights_name);
+ const auto& fc_weights_array = model.GetArray(fc_weights_name);
const auto& fc_weights_shape = fc_weights_array.shape();
CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
biasadd_op->add_input(matmul_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
- const auto& bias_array = *model.arrays.at(src_op.inputs[2]);
+ CHECK(model.HasArray(src_op.inputs[2]));
+ const auto& bias_array = model.GetArray(src_op.inputs[2]);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
Shape bias_shape_1d = bias_array.shape();
UnextendShape(&bias_shape_1d, 1);
*reshape_op->add_input() = softmax_size;
(*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
- const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape();
+ const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
int32 flattened_size = 1;
for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
flattened_size *= input_shape.dims(i);
// Op names have been chosen to match the tf.slim LSTM naming
// as closely as possible.
const int axis =
- model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
- ->shape()
+ model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
+ .shape()
.dimensions_count() -
1;
// Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
// Write weights
const string weights_output = base + "weights";
- CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
+ CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
const auto& weights_array =
- *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
+ model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
// Convert 4D FullyConnected weights into 2D matrix
const auto& weights_shape = weights_array.shape();
CHECK_EQ(weights_shape.dimensions_count(), 2);
// Write biases
const string biases_output = base + "biases";
- CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
+ CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
const auto& bias_array =
- *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
+ model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
Shape bias_shape_1d = bias_array.shape();
UnextendShape(&bias_shape_1d, 1);
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
- const auto& state_array = *model.arrays.at(name);
+ const auto& state_array = model.GetArray(name);
if (state_array.has_shape()) {
const auto& state_shape = state_array.shape();
const int kDims = state_shape.dimensions_count();
GraphDef* tensorflow_graph) {
for (const auto& input_array : model.flags.input_arrays()) {
AddPlaceholder(input_array.name(),
- model.arrays.at(input_array.name())->data_type,
+ model.GetArray(input_array.name()).data_type,
tensorflow_graph);
}
for (const auto& rnn_state : model.flags.rnn_states()) {
// by the above operators export. It's important that this comes
// after, as some operators need to export arrays that they reference
// in a specific way, rather than in the generic way done below.
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
const string& array_name = array_pair.first;
const auto& array = *array_pair.second;
if (array.buffer) {
CHECK_EQ(expand_op->inputs.size(), 2);
CHECK_EQ(expand_op->outputs.size(), 1);
- const auto& input_array = *model->arrays[expand_op->inputs[0]];
+ const auto& input_array = model->GetArray(expand_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return false;
return false;
}
- const auto& axis_array = *model->arrays[expand_op->inputs[1]];
+ const auto& axis_array = model->GetArray(expand_op->inputs[1]);
if (!axis_array.has_shape()) {
// Yield until input axis array shape has been resolved.
return false;
if (IsDiscardableArray(*model, axis_array_name) &&
CountOpsWithInput(*model, axis_array_name) == 1 &&
!GetOpWithOutput(*model, axis_array_name)) {
- model->arrays.erase(axis_array_name);
+ model->EraseArray(axis_array_name);
}
// Replace the operator in the graph.
depthwiseconv_op->outputs = {conv_op->outputs[0]};
if (conv_op->outputs.size() > 1) {
// delete the im2col array.
- model->arrays.erase(conv_op->outputs[1]);
+ model->EraseArray(conv_op->outputs[1]);
}
depthwiseconv_op->fused_activation_function =
conv_op->fused_activation_function;
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
- const auto& output_array = *model->arrays[transpose_op->outputs[0]];
+ const auto& output_array = model->GetArray(transpose_op->outputs[0]);
if (!output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
return false;
// Delete perm array if unused
if (IsDiscardableArray(*model, perm_array_name) &&
CountOpsWithInput(*model, perm_array_name) == 1) {
- model->arrays.erase(perm_array_name);
+ model->EraseArray(perm_array_name);
}
// Replace the operator in the graph.
// We already have an im2col array
return false;
}
- const auto& weights_array = *model->arrays[conv_op->inputs[1]];
+ const auto& weights_array = model->GetArray(conv_op->inputs[1]);
if (!weights_array.has_shape()) {
// We need to yield until weights dims have been resolved, because
// from the weights dims we determine whether an im2col array is
}
void ClearArrayQuantizationParams(const string& array_name, Model* model) {
- auto* array = model->arrays.at(array_name).get();
+ auto* array = &model->GetArray(array_name);
CHECK(array->quantization_params);
for (auto& input_array : *model->flags.mutable_input_arrays()) {
if (input_array.name() == array_name) {
bool DequantizeArray(const string& array_name,
GraphTransformation* transformation, Model* model) {
- auto* array = model->arrays.at(array_name).get();
+ auto* array = &model->GetArray(array_name);
if (!array->quantization_params) {
return false;
}
// Drop min/max inputs
for (int i = 1; i < fakequant_op->inputs.size(); i++) {
if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
- model->arrays.erase(fakequant_op->inputs[i]);
+ model->EraseArray(fakequant_op->inputs[i]);
}
}
fakequant_op->inputs.resize(1);
// Drop the im2col array.
CHECK_EQ(conv_op->outputs.size(), 2);
- model->arrays.erase(conv_op->outputs[1]);
+ model->EraseArray(conv_op->outputs[1]);
conv_op->outputs.resize(1);
AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
} else {
LOG(FATAL) << "Unhandled activation function type";
}
- model->arrays.erase(ac_op->inputs[0]);
+ model->EraseArray(ac_op->inputs[0]);
op->outputs[0] = ac_op->outputs[0];
model->operators.erase(ac_it);
return true;
AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
LogName(*following_op));
- model->arrays.erase(binary_op->outputs[0]);
+ model->EraseArray(binary_op->outputs[0]);
following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
const auto& old_constant_param_name =
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
- model->arrays.erase(old_constant_param_name);
+ model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
LOG(FATAL) << "should not get here";
}
- model->arrays.erase(preceding_op->outputs[0]);
+ model->EraseArray(preceding_op->outputs[0]);
preceding_op->outputs[0] = binary_op->outputs[0];
preceding_op->fused_activation_function =
binary_op->fused_activation_function;
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
- model->arrays.erase(old_constant_param_name);
+ model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
void PrintModelStats(const string& label, const Model& model) {
int quantized_arrays = 0;
- for (const auto& array : model.arrays) {
+ for (const auto& array : model.GetArrayMap()) {
if (array.second->quantization_params) {
quantized_arrays++;
}
}
LOG(INFO) << label << ": " << model.operators.size() << " operators, "
- << model.arrays.size() << " arrays (" << quantized_arrays
+ << model.GetArrayMap().size() << " arrays (" << quantized_arrays
<< " quantized)";
}
}
} while (found_new_useful_arrays);
// Erase arrays that aren't useful, and that are discardable.
- for (auto it = model->arrays.begin(); it != model->arrays.end();) {
- if (useful_arrays.count(it->first) ||
- !IsDiscardableArray(*model, it->first)) {
- ++it;
- } else {
- it = model->arrays.erase(it);
- }
- }
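+ // EraseArrays encapsulates the erase-while-iterating loop that each call
+ // site previously had to write out by hand.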
+ model->EraseArrays([&](const string& name) {
+ return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
+ });
// Erase operators that do not produce a useful output array.
for (auto it = model->operators.begin(); it != model->operators.end();) {
// Only need to test the first output, as we simultaneously added all of
std::vector<RnnState> rnn_states_to_keep;
for (const auto& rnn_state : model->flags.rnn_states()) {
const bool dangling =
- !model->arrays.count(rnn_state.back_edge_source_array()) ||
- !model->arrays.count(rnn_state.state_array());
+ !model->HasArray(rnn_state.back_edge_source_array()) ||
+ !model->HasArray(rnn_state.state_array());
if (dangling) {
CHECK(rnn_state.discardable());
} else {
// Erase the subgraph that is now replaced by L2Normalization
model->operators.erase(FindOperator(model, square_op));
- model->arrays.erase(sum_op->inputs[0]);
+ model->EraseArray(sum_op->inputs[0]);
if (sum_op->inputs.size() > 1) {
- model->arrays.erase(sum_op->inputs[1]);
+ model->EraseArray(sum_op->inputs[1]);
}
model->operators.erase(FindOperator(model, sum_op));
if (add_op) {
- model->arrays.erase(add_op->inputs[0]);
- model->arrays.erase(add_op->inputs[1]);
+ model->EraseArray(add_op->inputs[0]);
+ model->EraseArray(add_op->inputs[1]);
model->operators.erase(FindOperator(model, add_op));
}
- model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]);
+ model->EraseArray(sqrt_or_rsqrt_op->inputs[0]);
model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
- model->arrays.erase(div_or_mul_op->inputs[1]);
+ model->EraseArray(div_or_mul_op->inputs[1]);
model->operators.erase(FindOperator(model, div_or_mul_op));
return true;
}
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
// Erase intermediate arrays, keeping input to square op.
- model->arrays.erase(avpool_op->inputs[0]);
- model->arrays.erase(sqrt_op->inputs[0]);
+ model->EraseArray(avpool_op->inputs[0]);
+ model->EraseArray(sqrt_op->inputs[0]);
// Erase three operators being replaced.
model->operators.erase(FindOperator(model, square_op));
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
// Erase Maximum scalar input & operator
- model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+ model->EraseArray(maximum_op->inputs[scalar_input_index]);
model->operators.erase(FindOperator(model, maximum_op));
// Erase Minimum inputs & operator
- model->arrays.erase(minimum_op->inputs[0]);
- model->arrays.erase(minimum_op->inputs[1]);
+ model->EraseArray(minimum_op->inputs[0]);
+ model->EraseArray(minimum_op->inputs[1]);
model->operators.erase(FindOperator(model, minimum_op));
return true;
void SetDataTypeForAllOutputs(Model* model, Operator* op,
ArrayDataType data_type) {
for (const auto& output : op->outputs) {
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
}
} // namespace
// If the data type of some input is unknown, we need to yield.
for (const auto& input : op->inputs) {
if (!model->IsOptionalArray(input) &&
- model->arrays[input]->data_type == ArrayDataType::kNone) {
+ model->GetArray(input).data_type == ArrayDataType::kNone) {
return false;
}
}
// end if we changed anything, and return the correct boolean value.
std::unordered_map<string, ArrayDataType> old_output_data_types;
for (const auto& output : op->outputs) {
- old_output_data_types[output] = model->arrays[output]->data_type;
+ old_output_data_types[output] = model->GetArray(output).data_type;
}
// Do the actual output data types propagation.
if (op->type == OperatorType::kDequantize ||
op->type == OperatorType::kFill) {
// These operators produce an output with the same type as their 2nd input
CHECK_GE(op->inputs.size(), 2);
- const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kCast) {
// Data type of the Cast op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* cast_op = static_cast<CastOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
} else if (op->type == OperatorType::kArgMax) {
// Data type of the ArgMax op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* argmax_op = static_cast<ArgMaxOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type;
+ model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
} else if (op->type == OperatorType::kRange) {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
} else {
// Otherwise use the first input
CHECK_GE(op->inputs.size(), 1);
- data_type = model->arrays[op->inputs[0]]->data_type;
+ data_type = model->GetArray(op->inputs[0]).data_type;
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, data_type);
for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
auto output = op->outputs[i];
auto data_type = unsupported_op->output_data_types[i];
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
} else if (op->type == OperatorType::kExpandDims) {
// Yield on ExpandDim until it is converted to Reshape
} else {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
- const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
}
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
- if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ if (old_output_data_types[output] != model->GetArray(output).data_type) {
return true;
}
}
int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
const string& weights_name = op.inputs[1];
- const auto& weights_shape = model.arrays.at(weights_name)->shape();
+ const auto& weights_shape = model.GetArray(weights_name).shape();
if (op.type == OperatorType::kConv ||
op.type == OperatorType::kFullyConnected) {
return weights_shape.dims(0);
bool EnsureBiasVectorShape(Model* model, Operator* op) {
const string& weights_name = op->inputs[1];
- const auto& weights_array = *model->arrays[weights_name];
+ const auto& weights_array = model->GetArray(weights_name);
// Yield until weights shape has been resolved.
if (!weights_array.has_shape()) {
return false;
if (op->inputs.size() < 3) {
return false;
}
- auto& bias_array = *model->arrays[op->inputs[2]];
+ auto& bias_array = model->GetArray(op->inputs[2]);
if (bias_array.has_shape()) {
return true;
}
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
const auto& input_shape = input_array.shape();
CHECK_EQ(input_shape.dimensions_count(), 4);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
if (op->outputs.size() == 2) {
const auto& output_shape = output_array.shape();
const int input_depth = weights_shape.dims(3);
- auto& im2col_array = *model->arrays[op->outputs[1]];
+ auto& im2col_array = model->GetArray(op->outputs[1]);
im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
output_shape.dims(2),
input_depth * kheight * kwidth});
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
const auto& input_shape = input_array.shape();
CHECK_EQ(input_shape.dimensions_count(), 4);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
}
void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
}
void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
void ProcessFillOperator(Model* model, FillOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
const auto& input_shape = input_array.shape();
CHECK_GE(input_shape.dimensions_count(), 1);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
void ProcessTensorFlowReshapeOperator(Model* model,
TensorFlowReshapeOperator* op) {
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
}
void ProcessSimpleOperator(Model* model, Operator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
}
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
if (output_array.has_shape()) {
return;
}
void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
CHECK_EQ(op->inputs.size(), 2);
- const auto& input0_array = *model->arrays[op->inputs[0]];
- const auto& input1_array = *model->arrays[op->inputs[1]];
+ const auto& input0_array = model->GetArray(op->inputs[0]);
+ const auto& input1_array = model->GetArray(op->inputs[1]);
// Yield until input dims have been resolved.
if (!input0_array.has_shape() || !input1_array.has_shape()) {
return;
}
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
&output_array);
}
void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
CHECK_LE(op->inputs.size(), 2);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
return;
}
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = *model->arrays[op->inputs[1]];
+ const auto& reduction_array = model->GetArray(op->inputs[1]);
if (!reduction_array.buffer) {
return;
}
if (op->begin.empty()) return;
// Yield until input dims have been resolved.
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) return;
const Shape& input_shape = input_array.shape();
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
CHECK_EQ(input_shape.dims().size(), op->size.size());
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
// Yield until input dims have been resolved.
for (const auto& input_name : op->inputs) {
- auto& input_array = *model->arrays[input_name];
+ auto& input_array = model->GetArray(input_name);
if (!input_array.has_shape()) {
return;
}
}
auto& output_array = model->GetArray(op->outputs[0]);
// Use input 0 as the basis for output dimensions.
- const auto& first_input_array = *model->arrays[op->inputs[0]];
+ const auto& first_input_array = model->GetArray(op->inputs[0]);
output_array.copy_shape(first_input_array.shape());
// Determine the concat size, and enforce that all inputs have
// the same dimensions count.
int concat_size = 0;
for (const auto& input_name : op->inputs) {
- auto& input_array = *model->arrays[input_name];
+ auto& input_array = model->GetArray(input_name);
CHECK(input_array.has_shape());
if (input_array.shape().dimensions_count() == 0) {
continue;
void ProcessRangeOperator(Model* model, RangeOperator* op) {
CHECK_EQ(op->inputs.size(), 3);
- const auto& start_array = *model->arrays[op->inputs[0]];
+ const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until input dims have been resolved.
return;
}
- const auto& limit_array = *model->arrays[op->inputs[1]];
+ const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
return;
}
- const auto& delta_array = *model->arrays[op->inputs[2]];
+ const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
return;
}
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
const string& input_name = op->inputs[1];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
CHECK_EQ(op->outputs.size(), op->num_split);
for (const auto& output : op->outputs) {
- model->arrays[output]->copy_shape(output_shape);
+ model->GetArray(output).copy_shape(output_shape);
}
}
void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- if (!model->arrays[op->inputs[0]]->has_shape() ||
- !model->arrays[op->inputs[1]]->has_shape()) {
+ if (!model->GetArray(op->inputs[0]).has_shape() ||
+ !model->GetArray(op->inputs[1]).has_shape()) {
return;
}
- const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();
+ const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
const string& output_size_name = op->inputs[1];
- const auto& output_size_array = *model->arrays[output_size_name];
+ const auto& output_size_array = model->GetArray(output_size_name);
CHECK(output_size_array.data_type == ArrayDataType::kInt32);
CHECK(output_size_array.has_shape());
const auto& output_size_shape = output_size_array.shape();
}
std::vector<int32> output_shape =
output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
- input_data_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
+ output_shape[1], input_data_shape.dims(3)}));
}
void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
const auto& input_array =
- *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
// Yield until all input dims have been resolved.
if (!input_array.has_shape()) {
return;
CHECK_GE(input_shape.dimensions_count(), 2);
const auto& prev_activ_array =
- *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
// Yield until all input dims have been resolved.
if (!prev_activ_array.has_shape()) {
return;
CHECK_GE(prev_activ_shape.dimensions_count(), 2);
const auto& weights_array =
- *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
CHECK_EQ(weights_shape.dimensions_count(), 2);
const auto& bias_array =
- *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
// Yield until bias dims have been resolved.
if (!bias_array.has_shape()) {
return;
CHECK_GE(bias_shape.dimensions_count(), 1);
const auto& prev_state_array =
- *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
// Yield until all input dims have been resolved.
if (!prev_state_array.has_shape()) {
return;
}
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
const auto input_height = input_shape.dims(1);
const auto input_width = input_shape.dims(2);
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
- const auto& paddings_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
+ const auto& paddings_array = model->GetArray(op->inputs[2]);
const auto& block_shape_array_shape = block_shape_array.shape();
const auto& paddings_array_shape = paddings_array.shape();
QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
int output_height = height_with_paddings / block_height;
int output_width = width_with_paddings / block_width;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_shape.dims(0) * block_height * block_width, output_height,
- output_width, input_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
+ output_height, output_width, input_shape.dims(3)}));
}
void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
const auto input_height = input_shape.dims(1);
const auto input_width = input_shape.dims(2);
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
- const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
+ const auto& crops_array = model->GetArray(op->inputs[2]);
const auto& block_shape_array_shape = block_shape_array.shape();
const auto& crops_array_shape = crops_array.shape();
QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
int output_height = input_height * block_height;
int output_width = input_width * block_width;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_shape.dims(0) / (block_height * block_width), output_height,
- output_width, input_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
+ output_height, output_width, input_shape.dims(3)}));
}
void ProcessGatherOperator(Model* model, GatherOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
- const auto& indices_array = *model->arrays[op->inputs[1]];
- auto& output_array = *model->arrays[op->outputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ const auto& indices_array = model->GetArray(op->inputs[1]);
+ auto& output_array = model->GetArray(op->outputs[0]);
// Bail if we already know the output shape.
if (output_array.has_shape()) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) return;
if (op->left_padding.empty()) return;
CHECK_EQ(op->left_padding.size(), op->right_padding.size());
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
Shape output_shape = input_array.shape();
void ProcessRankOperator(Model* model, RankOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
void ProcessStackOperator(Model* model, StackOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
CHECK_EQ(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) return;
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
const std::vector<int>& input_dims = input_array.shape().dims();
void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) return;
- auto& weights_feature_array = *model->arrays[op->inputs[1]];
+ auto& weights_feature_array = model->GetArray(op->inputs[1]);
if (!weights_feature_array.has_shape()) return;
- const auto& weights_time_array = *model->arrays[op->inputs[2]];
+ const auto& weights_time_array = model->GetArray(op->inputs[2]);
if (!weights_time_array.has_shape()) return;
const bool has_bias = (op->inputs.size() == 4);
if (has_bias) {
- const auto& bias_array = *model->arrays[op->inputs[3]];
+ const auto& bias_array = model->GetArray(op->inputs[3]);
if (!bias_array.has_shape()) return;
}
}
void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
}
output_dims.push_back(1);
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
if (output_array.has_shape()) {
return;
}
auto* op = it->get();
std::unordered_map<string, std::vector<int>> old_output_dims;
for (const auto& output : op->outputs) {
- if (model->arrays[output]->has_shape()) {
- old_output_dims[output] = model->arrays[output]->shape().dims();
+ if (model->GetArray(output).has_shape()) {
+ old_output_dims[output] = model->GetArray(output).shape().dims();
}
}
// Return true if any output dim changed, false if none changed.
// Assumption: no transformation clears an output shape, they only add shapes.
for (const auto& output : op->outputs) {
- if (model->arrays[output]->has_shape() &&
- (old_output_dims[output] != model->arrays[output]->shape().dims())) {
+ if (model->GetArray(output).has_shape() &&
+ (old_output_dims[output] != model->GetArray(output).shape().dims())) {
AddMessageF("Set shape of %s to [%s]", output,
- absl::StrJoin(model->arrays[output]->shape().dims(), ","));
+ absl::StrJoin(model->GetArray(output).shape().dims(), ","));
return true;
}
}
model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
}
}
- model->arrays.erase(dequantize_op->outputs[0]);
+ model->EraseArray(dequantize_op->outputs[0]);
model->operators.erase(dequantize_it);
}
}
// else.
for (int i = 1; i <= 2; i++) {
if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
- model->arrays.erase(fakequant_op->inputs[i]);
+ model->EraseArray(fakequant_op->inputs[i]);
}
}
fakequant_op->inputs.resize(1);
// Remove the node and its output array.
AddMessageF("Removed final %s", LogName(*dequantize_op));
- model->arrays.erase(output);
+ model->EraseArray(output);
model->operators.erase(dequantize_it);
return true;
}
// Now check if the constant operand makes this binary
// operator trivial.
const auto& constant_input_array =
- *model->arrays[binary_op->inputs[index_of_constant_input]];
+ model->GetArray(binary_op->inputs[index_of_constant_input]);
// For now, we only handle floats here.
if (constant_input_array.data_type != ArrayDataType::kFloat) {
return false;
for (const string& input : trivial_inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
concat_op->inputs = nontrivial_inputs;
}
}
if (!is_referenced) {
- model->arrays.erase(removal_candidate);
+ model->EraseArray(removal_candidate);
}
}
// the model. We allow specifying an arbitrary input_array,
// treating the part of the graph leading up to it as unused.
for (const auto& output : op->outputs) {
- CHECK(model->arrays.count(output));
+ CHECK(model->HasArray(output));
// If this output is provided as the model's input array,
// then we don't need this operator to produce its contents.
if (IsInputArray(*model, output)) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1 &&
!GetOpWithOutput(*model, input)) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
continue;
}
// Generic case: do delete this output array.
- model->arrays.erase(output);
+ model->EraseArray(output);
}
model->operators.erase(it);
return true;
}
// Remove the old param arrays
- model->arrays.erase(bn_op->inputs[1]);
- model->arrays.erase(bn_op->inputs[2]);
- model->arrays.erase(bn_op->inputs[3]);
+ model->EraseArray(bn_op->inputs[1]);
+ model->EraseArray(bn_op->inputs[2]);
+ model->EraseArray(bn_op->inputs[3]);
// Remove the old operator
DCHECK_EQ(bn_it->get(), bn_op);
return false;
// Handle crops
- const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& crops_array = model->GetArray(op->inputs[2]);
if (!crops_array.has_shape()) return false;
const std::vector<int>& crops_dims = crops_array.shape().dims();
if (crops_dims.size() != 2) {
}
// Handle block_shape
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
const Operator* binary_op) {
- const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
- const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
+ const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
+ const auto output_data_type =
+ model->GetArray(binary_op->outputs[0]).data_type;
#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
if (inputs_data_type == InputsDataType && \
output_data_type == OutputDataType) { \
return false;
}
- auto& output_array = *model->arrays[binary_op->outputs[0]];
+ auto& output_array = model->GetArray(binary_op->outputs[0]);
// Yield until the output array dims have been resolved.
if (!output_array.has_shape()) {
return false;
// Remove the binary operator and its inputs
if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
- model->arrays.erase(binary_op->inputs[0]);
+ model->EraseArray(binary_op->inputs[0]);
}
if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
- model->arrays.erase(binary_op->inputs[1]);
+ model->EraseArray(binary_op->inputs[1]);
}
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
// Remove all the resolved arrays.
for (const string& input_name : concat_op->inputs) {
- model->arrays.erase(input_name);
+ model->EraseArray(input_name);
}
// Remove concatenate operator
output_buffer.data[i] = dst_val;
}
if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
- model->arrays.erase(fakequant_op->inputs[0]);
+ model->EraseArray(fakequant_op->inputs[0]);
}
model->operators.erase(fakequant_it);
// Erase input arrays if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
- model->arrays.erase(op->inputs[1]);
+ model->EraseArray(op->inputs[1]);
}
// Erase the operator
auto* op = static_cast<RangeOperator*>(base_op);
CHECK_EQ(op->inputs.size(), 3);
- const auto& start_array = *model->arrays[op->inputs[0]];
+ const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
}
- const auto& limit_array = *model->arrays[op->inputs[1]];
+ const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
}
- const auto& delta_array = *model->arrays[op->inputs[2]];
+ const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
}
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
// Delete the input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
- model->arrays.erase(op->inputs[1]);
+ model->EraseArray(op->inputs[1]);
}
if (IsDiscardableArray(*model, op->inputs[2]) &&
CountOpsWithInput(*model, op->inputs[2]) == 1) {
- model->arrays.erase(op->inputs[2]);
+ model->EraseArray(op->inputs[2]);
}
// Delete the operator
// Delete the input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
model->operators.erase(it);
for (const auto& input : op->inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
// Erase input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
// Erase the operator
}
for (const auto& input : unary_op->inputs) {
if (CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
AddMessageF("Resolved constant %s to the equivalent constant array",
if (op->inputs.size() != 2) return false;
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- const auto& indices_array = *model->arrays[op->inputs[1]];
+ const auto& indices_array = model->GetArray(op->inputs[1]);
if (!indices_array.has_shape()) return false;
op->axis = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
return true;
CHECK_EQ(op->inputs.size(), 2);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- const auto& array = *model->arrays[op->inputs[1]];
+ const auto& array = model->GetArray(op->inputs[1]);
if (!array.has_shape()) return false;
const std::vector<int>& dims = array.shape().dims();
AddMessageF("Reordered axes for array %s", input_array_name);
// Remove the op and output array.
- model->arrays.erase(output_array_name);
+ model->EraseArray(output_array_name);
model->operators.erase(reorder_it);
return true;
}
if (!op->shape.empty()) return false;
if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
- const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]];
+ const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]);
op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
}
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
- const auto& begin_array = *model->arrays[op->inputs[1]];
+ const auto& begin_array = model->GetArray(op->inputs[1]);
if (!begin_array.has_shape()) return false;
- const auto& size_array = *model->arrays[op->inputs[2]];
+ const auto& size_array = model->GetArray(op->inputs[2]);
if (!size_array.has_shape()) return false;
op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
return false;
// Handle paddings.
- const auto& paddings_array = *model->arrays[op->inputs[paddings_index]];
+ const auto& paddings_array = model->GetArray(op->inputs[paddings_index]);
if (!paddings_array.has_shape()) return false;
const std::vector<int>& paddings_dims = paddings_array.shape().dims();
if (paddings_dims.size() != 2) {
}
// Handle block_shape.
- const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]];
+ const auto& block_shape_array =
+ model->GetArray(op->inputs[block_shape_index]);
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
}
CHECK_EQ(op->inputs.size(), 4);
- const auto& start_array = *model->arrays[op->inputs[1]];
+ const auto& start_array = model->GetArray(op->inputs[1]);
if (!start_array.has_shape()) return false;
- const auto& stop_array = *model->arrays[op->inputs[2]];
+ const auto& stop_array = model->GetArray(op->inputs[2]);
if (!stop_array.has_shape()) return false;
- const auto& stride_array = *model->arrays[op->inputs[3]];
+ const auto& stride_array = model->GetArray(op->inputs[3]);
if (!stride_array.has_shape()) return false;
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
// Remove the axis array if it is not used by anything else.
if (CountOpsWithInput(*model, axis_name) == 1) {
- model->arrays.erase(axis_name);
+ model->EraseArray(axis_name);
}
// Remove the TensorFlowConcat op
model->operators.erase(concat_it);
LogName(*matmul_op), LogName(*fc_op));
const auto& previous_op_output = previous_op->outputs[0];
if (CountOpsWithInput(*model, previous_op_output) == 1) {
- model->arrays.erase(previous_op_output);
+ model->EraseArray(previous_op_output);
}
CHECK_EQ(previous_op->inputs.size(), 2);
fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]};
const auto& previous_op_shape = previous_op->inputs[1];
if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
!GetOpWithOutput(*model, previous_op_shape)) {
- model->arrays.erase(previous_op_shape);
+ model->EraseArray(previous_op_shape);
}
model->operators.erase(previous_op_it);
}
// Remove the node and its output array.
AddMessageF("Removing already-resolved %s", LogName(*merge_op));
- model->arrays.erase(merge_op->outputs[0]);
+ model->EraseArray(merge_op->outputs[0]);
model->operators.erase(merge_it);
return true;
}
// Remove the output arrays if they are now unused.
for (int i = 0; i < 2; i++) {
if (!GetOpWithInput(*model, switch_op->outputs[i])) {
- model->arrays.erase(switch_op->outputs[i]);
+ model->EraseArray(switch_op->outputs[i]);
}
}
// Remove input arrays if they are only used by the switch itself and aren't
for (const auto& input : switch_op->inputs) {
if (CountOpsWithInput(*model, input) == 1 &&
!GetOpWithOutput(*model, input)) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
// Remove the switch node itself.
model->operators.erase(tile_it);
if (!CountOpsWithInput(*model, tile_multiplier_array) &&
!GetOpWithOutput(*model, tile_multiplier_array)) {
- model->arrays.erase(tile_multiplier_array);
+ model->EraseArray(tile_multiplier_array);
}
if (!CountOpsWithInput(*model, tile_output_array)) {
- model->arrays.erase(tile_output_array);
+ model->EraseArray(tile_output_array);
}
}
} // namespace
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
// Handling perm.
- const auto& perm_array = *model->arrays[op->inputs[1]];
+ const auto& perm_array = model->GetArray(op->inputs[1]);
if (!perm_array.has_shape()) return false;
const std::vector<int>& perm_dims = perm_array.shape().dims();
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-//#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22.,
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12.,
ac_op->outputs = op->outputs;
const string& tmp_array_name =
AvailableArrayName(*model, op->outputs[0] + "_unfused");
- CHECK(!model->arrays.count(tmp_array_name));
+ CHECK(!model->HasArray(tmp_array_name));
model->GetOrCreateArray(tmp_array_name);
ac_op->inputs = {tmp_array_name};
op->outputs = {tmp_array_name};
output = string(absl::StripPrefix(output, "^"));
}
}
- for (auto& array : model->arrays) {
+ for (auto& array : model->GetArrayMap()) {
if (absl::StartsWith(array.first, "^")) {
LOG(FATAL) << "What?";
}
// Our Model struct represents an entire model (our "top-level" struct).
// Owns everything.
-struct Model {
+class Model {
+ public:
+ using ArrayMap = std::unordered_map<string, std::unique_ptr<Array>>;
+
+ bool HasArray(const string& name) const { return arrays.count(name) > 0; }
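+ // GetArray() below assumes the array exists (enforced only by DCHECK);
+ // callers that are unsure should test HasArray() first.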
Array& GetArray(const string& name) const {
- DCHECK(arrays.count(name));
+ DCHECK(HasArray(name));
return *arrays.at(name);
}
Array& GetOrCreateArray(const string& name) {
// Make sure name is not used by an optional array
DCHECK(!optional_arrays.count(name));
- if (!arrays.count(name)) {
+ if (!HasArray(name)) {
Array* ptr = new Array;
arrays[name] = std::unique_ptr<Array>(ptr);
  }
  return GetArray(name);
}
bool IsOptionalArray(const string& name) const {
  return optional_arrays.count(name);
}
+ // Note that this invalidates all array iterators.
+ void EraseArray(const string& name) { arrays.erase(name); }
+ void EraseArrays(std::function<bool(const string&)> discardable) {
+ for (auto it = arrays.begin(); it != arrays.end();) {
+ if (discardable(it->first)) {
+ it = arrays.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+ const ArrayMap& GetArrayMap() const { return arrays; }
+
// Optional arrays are used for optional tensors: these tensors carry no
// data, but their reserved names may appear as op inputs.
std::set<string> optional_arrays;
+
// The list of operators. Notice how it's a list of unique_ptr's, implying
// that the Model is what owns Operator's and keeps them alive.
std::vector<std::unique_ptr<Operator>> operators;
- // The associative array mapping names to Array's.
- // Notice how it's a container of unique_ptr's, implying
- // that the Model is what owns Array's and keeps them alive.
- // The Operator's refer to these Array's by their name strings, not by their
- // addresses. See Operator::inputs, Operator::outputs.
- std::unordered_map<string, std::unique_ptr<Array>> arrays;
+
// Generic flags, a place where we combine information passed to us via
// command-line parameters (e.g. --input_width=N) with information that
// we may or may not find in the input model file.
ModelFlags flags;
// For code-generation only: required size of the transient_data buffer
std::size_t transient_data_size = 0;
// For code-generation only: required alignment of the transient_data buffer
std::size_t transient_data_alignment = 0;
+
+ private:
+ // The associative array mapping names to Array's.
+ // Notice how it's a container of unique_ptr's, implying
+ // that the Model is what owns Array's and keeps them alive.
+ // The Operator's refer to these Array's by their name strings, not by their
+ // addresses. See Operator::inputs, Operator::outputs.
+  ArrayMap arrays;
};
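// A minimal usage sketch (editor's illustration, not part of the patch) of
// the new accessor surface; the array name here is hypothetical.
inline void SketchModelAccessors(Model* model) {
  const string name = "some_array";  // hypothetical
  if (!model->HasArray(name)) {
    model->GetOrCreateArray(name);  // inserts an empty Array under `name`
  }
  Array& array = model->GetArray(name);  // DCHECK-fails if `name` is absent
  (void)array;
  // Bulk traversal now goes through the read-only map accessor:
  for (const auto& pair : model->GetArrayMap()) {
    LOG(INFO) << "array: " << pair.first;
  }
  // Erasure goes through EraseArray/EraseArrays; note that both
  // invalidate all array iterators.
  model->EraseArray(name);
}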
} // namespace toco
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
// First find a list of unique array names.
std::set<string> names;
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
names.insert(array_pair.first);
}
// tensors in the tensors_map.
std::map<int, Offset<Tensor>> ordered_tensors;
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
const string& tensor_name = array_pair.first;
const toco::Array& array = *array_pair.second;
auto model = Import(ModelFlags(), InputModelAsString());
- ASSERT_GT(model->arrays.count("tensor_one"), 0);
+ ASSERT_TRUE(model->HasArray("tensor_one"));
Array& a1 = model->GetArray("tensor_one");
EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
for (int i = 0; i < model->flags.input_arrays_size(); i++) {
string const& array_name = model->flags.input_arrays(i).name();
- auto* array = model->arrays[array_name].get();
+ auto* array = &model->GetArray(array_name);
// Note that the notion of changing data types only applies to real-numbers
// arrays (see the documentation for inference_input_type).
// TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
bool DeleteArrayIfUnused(const string& array_name, Model* model) {
if (CountOpsWithInput(*model, array_name) == 0) {
- model->arrays.erase(array_name);
+ model->EraseArray(array_name);
return true;
}
return false;
}
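// A sketch (assumed names, editor's illustration) of the typical cleanup
// pattern after a transformation rewires an op away from one of its inputs.
inline void SketchRewireAndCleanup(Model* model, Operator* op,
                                   const string& new_input) {
  const string old_input = op->inputs[0];
  op->inputs[0] = new_input;
  // Erases old_input only if no remaining op still consumes it.
  DeleteArrayIfUnused(old_input, model);
}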
bool IsConstantParameterArray(const Model& model, const string& name) {
- if (!model.arrays.count(name)) {
+ if (!model.HasArray(name)) {
return false;
}
- return !!model.arrays.at(name)->buffer;
+ return !!model.GetArray(name).buffer;
}
namespace {
return;
}
for (const auto& input_array : model.flags.input_arrays()) {
- CHECK(model.arrays.count(input_array.name()))
+ CHECK(model.HasArray(input_array.name()))
<< "Input array not found: " << input_array.name();
}
for (const string& output_array : model.flags.output_arrays()) {
- CHECK(model.arrays.count(output_array))
+ CHECK(model.HasArray(output_array))
<< "Output array not found: " << output_array;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
- CHECK(model.arrays.count(rnn_state.state_array()));
- CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+ CHECK(model.HasArray(rnn_state.state_array()));
+ CHECK(model.HasArray(rnn_state.back_edge_source_array()));
}
}
}
void CheckNoMissingArray(const Model& model) {
for (const auto& op : model.operators) {
for (const auto& input : op->inputs) {
- CHECK(model.arrays.count(input) || model.optional_arrays.count(input));
+ CHECK(model.HasArray(input) || model.optional_arrays.count(input));
}
for (const auto& output : op->outputs) {
- CHECK(model.arrays.count(output));
+ CHECK(model.HasArray(output));
}
}
CheckNonExistentIOArrays(model);
void FixNoMissingArray(Model* model) {
for (const auto& op : model->operators) {
for (const auto& input : op->inputs) {
- if (!model->arrays.count(input)) {
+ if (!model->HasArray(input)) {
model->GetOrCreateArray(input);
}
}
for (const auto& output : op->outputs) {
- if (!model->arrays.count(output)) {
+ if (!model->HasArray(output)) {
model->GetOrCreateArray(output);
}
}
void CheckNoOrphanedArray(const Model& model) {
std::unordered_set<string> arrays_without_known_use;
- for (const auto& array : model.arrays) {
+ for (const auto& array : model.GetArrayMap()) {
if (IsDiscardableArray(model, array.first)) {
arrays_without_known_use.insert(array.first);
}
void FixNoOrphanedArray(Model* model) {
std::unordered_set<string> arrays_without_known_use;
- for (const auto& array : model->arrays) {
+ for (const auto& array : model->GetArrayMap()) {
arrays_without_known_use.insert(array.first);
}
for (const auto& op : model->operators) {
}
for (const auto& array : arrays_without_known_use) {
if (IsDiscardableArray(*model, array)) {
- model->arrays.erase(array);
+ model->EraseArray(array);
}
}
}
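// The same erase-if-discardable loop could also be phrased through the new
// Model::EraseArrays helper; a sketch reusing the arrays_without_known_use
// set built above:
inline void SketchEraseOrphans(Model* model,
                               const std::unordered_set<string>& orphans) {
  model->EraseArrays([&](const string& name) {
    return orphans.count(name) > 0 && IsDiscardableArray(*model, name);
  });
}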
void CheckArrayFieldsConsistent(const Model& model) {
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = array_entry.second;
if (array->has_shape()) {
for (int d : array->shape().dims()) {
void CheckOperatorOrdering(const Model& model) {
std::unordered_set<string> arrays_behind_us;
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
if (!GetOpWithOutput(model, array_entry.first)) {
arrays_behind_us.insert(array_entry.first);
}
void FixOperatorOrdering(Model* model) {
std::unordered_set<string> arrays_behind_us;
- for (const auto& array_entry : model->arrays) {
+ for (const auto& array_entry : model->GetArrayMap()) {
if (!GetOpWithOutput(*model, array_entry.first)) {
arrays_behind_us.insert(array_entry.first);
}
if (count_type == "None") {
continue;
} else if (count_type == "Arrays") {
- CheckCountInRange(model_check, model.arrays.size(), "count of arrays");
+ CheckCountInRange(model_check, model.GetArrayMap().size(),
+ "count of arrays");
} else if (count_type == "Total") {
CheckCountInRange(model_check, model.operators.size(),
"count of all operator instances");
return false;
}
}
- const auto& array = model.arrays.at(array_name);
+ const auto* array = &model.GetArray(array_name);
// An array with a constant buffer isn't a transient array.
if (!!array->buffer) {
return false;
}
string AvailableArrayName(const Model& model, const string& name) {
- if (!model.arrays.count(name) && !model.optional_arrays.count(name)) {
+ if (!model.HasArray(name) && !model.optional_arrays.count(name)) {
return name;
}
const int kNumSuffixesToTry = 1000;
for (int i = 0; i < kNumSuffixesToTry; i++) {
const string& name_with_suffix = toco::port::StringF("%s_%d", name, i);
- if (!model.arrays.count(name_with_suffix)) {
+ if (!model.HasArray(name_with_suffix)) {
return name_with_suffix;
}
}
LOG(FATAL) << "Could not find an available array name starting with "
           << name;
return "";
}
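// Usage sketch (hypothetical base name): reserve a fresh, collision-free
// array, mirroring the *_unfused pattern shown earlier.
inline void SketchFreshArray(Model* model) {
  const string fresh = AvailableArrayName(*model, "conv_out");
  CHECK(!model->HasArray(fresh));
  model->GetOrCreateArray(fresh);
}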
void PrintArrayShape(Model* model, const string& name) {
- if (!model->arrays[name]->has_shape()) {
+ if (!model->GetArray(name).has_shape()) {
LOG(INFO) << name << " has no shape";
return;
}
LOG(INFO) << name
- << " has shape: " << ShapeToString(model->arrays[name]->shape());
+ << " has shape: " << ShapeToString(model->GetArray(name).shape());
}
bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
}
void CheckFinalDataTypesSatisfied(const Model& model) {
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = *array_entry.second;
if (array.final_data_type != ArrayDataType::kNone) {
CHECK(array.final_data_type == array.data_type)