// model that does not already contain such MinMax information.
message ArraysExtraInfo {
message Entry {
- // Next ID to use: 5.
+ // Next ID to use: 7.
optional string name = 1;
optional float min = 2;
optional float max = 3;
optional IODataType data_type = 4;
+ optional InputArrayShape shape = 5;
+ optional float constant_float_value = 6;
}
repeated Entry entries = 1;
}
void UseArraysExtraInfo(Model* model) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
- QCHECK(model->HasArray(entry.name()))
- << "ArraysExtraInfo refers to non-existent array name: "
- << entry.name();
+ if (!model->HasArray(entry.name())) {
+ continue;
+ }
auto& array = model->GetArray(entry.name());
auto& minmax = array.GetOrCreateMinMax();
if (entry.has_min() || entry.has_max()) {
array.final_data_type =
ConvertIODataTypeToArrayDataType(entry.data_type());
}
+ if (entry.has_shape()) {
+ array.clear_shape();
+ // Make sure to create the shape even if there are no dims, to
+ // correctly record 0-D shapes.
+ array.mutable_shape();
+ for (int dim : entry.shape().dims()) {
+ array.mutable_shape()->mutable_dims()->push_back(dim);
+ }
+ }
+ if (entry.has_constant_float_value()) {
+ CHECK(array.has_shape());
+ CHECK(array.data_type == ArrayDataType::kFloat);
+ auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ data.resize(RequiredBufferSizeForShape(array.shape()));
+ for (float& f : data) {
+ f = entry.constant_float_value();
+ }
+ }
}
}