}
}
+void FixEdgeArrays(Model* model) {
+ for (const string& output_array_name : model->flags.output_arrays()) {
+ if (!GetOpWithOutput(*model, output_array_name)) {
+ // Output has no operator producing it. Change that by inserting a copy.
+ LOG(WARNING) << "Fixing constant output array " << output_array_name
+ << " by inserting a copy. This is not optimal.";
+ string intermediate_array_name =
+ AvailableArrayName(*model, output_array_name + "_copy");
+ CloneArray(model, output_array_name, intermediate_array_name);
+ InsertCopyOperator(model, intermediate_array_name, output_array_name);
+ }
+ }
+}
+
+void InsertCopyOperator(Model* model, const string& source_array_name,
+ const string& target_array_name) {
+ // Drop constant data from the target array as the copy will be done at
+ // runtime.
+ Array& target_array = model->GetOrCreateArray(target_array_name);
+ target_array.buffer.reset();
+
+ // Reshape to the same size. This should be a no-op.
+ const Array& source_array = model->GetArray(source_array_name);
+ std::vector<int> shape = source_array.shape().dims();
+
+ // Insert copy operator.
+ auto* copy_op = new TensorFlowReshapeOperator;
+ copy_op->inputs = {
+ source_array_name,
+ CreateInt32Array(model, target_array_name + "_copy_shape", shape)};
+ copy_op->outputs = {target_array_name};
+ model->operators.emplace_back(copy_op);
+}
+
+namespace {
+template <ArrayDataType A>
+void CopyArrayBuffer(const Array& source_array, Array* target_array) {
+ if (source_array.buffer) {
+ const auto& source_buffer = source_array.GetBuffer<A>();
+ auto& target_buffer = target_array->GetMutableBuffer<A>();
+ target_buffer.data = source_buffer.data;
+ }
+}
+} // namespace
+
+void CloneArray(Model* model, const string& source_array_name,
+ const string& target_array_name) {
+ CHECK(!model->HasArray(target_array_name));
+ const Array& source_array = model->GetArray(source_array_name);
+ Array& target_array = model->GetOrCreateArray(target_array_name);
+
+ switch (source_array.data_type) {
+ case ArrayDataType::kBool:
+ CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
+ break;
+ case ArrayDataType::kFloat:
+ CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
+ break;
+ case ArrayDataType::kInt8:
+ CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
+ break;
+ case ArrayDataType::kUint8:
+ CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
+ break;
+ case ArrayDataType::kInt16:
+ CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
+ break;
+ case ArrayDataType::kUint16:
+ CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
+ break;
+ case ArrayDataType::kInt32:
+ CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
+ break;
+ case ArrayDataType::kUint32:
+ CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
+ break;
+ case ArrayDataType::kInt64:
+ CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
+ break;
+ case ArrayDataType::kUint64:
+ CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
+ break;
+ case ArrayDataType::kString:
+ CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type: "
+ << ArrayDataTypeName(source_array.data_type);
+ return;
+ }
+
+ if (source_array.minmax) {
+ const auto& smm = source_array.GetMinMax();
+ auto& tmm = target_array.GetOrCreateMinMax();
+ tmm.min = smm.min;
+ tmm.max = smm.max;
+ }
+
+ if (source_array.quantization_params) {
+ const auto& sqp = source_array.GetQuantizationParams();
+ auto& tqp = target_array.GetOrCreateQuantizationParams();
+ tqp.zero_point = sqp.zero_point;
+ tqp.scale = sqp.scale;
+ }
+
+ target_array.data_type = source_array.data_type;
+ target_array.final_data_type = source_array.final_data_type;
+
+ target_array.copy_shape(source_array.shape());
+}
+
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
std::vector<int>* out_dims) {
CHECK(out_dims->empty());