3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
22 import tflite.SubGraph
23 import tflite.BuiltinOptions
# Assume we use only main model in model file
# Get selected operators from file, and return operator index list
def GetOperatorList(oplist_file):
    """Parse an operator-selection file into a list of operator indices.

    Every whitespace-separated token, on any line, must be either a single
    non-negative index (``"3"``) or an inclusive range (``"2-5"``).
    Indices are returned in the order they appear; duplicates are kept.

    Args:
        oplist_file: An open text file (any object whose ``readlines()``
            returns ``str`` lines).

    Returns:
        list[int]: The selected operator indices.

    Raises:
        SystemExit: With status 1 when a token is malformed or when no
            operator is selected at all.
    """
    import sys

    opcode_list = []

    for line in oplist_file.readlines():
        for word in line.split():
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as superscripts ("\u00b2") that int() rejects.
            if word.isdecimal():
                opcode_list.append(int(word))
                continue

            opcode_range = word.split('-')
            if (len(opcode_range) == 2 and opcode_range[0].isdecimal()
                    and opcode_range[1].isdecimal()):
                start = int(opcode_range[0])
                end = int(opcode_range[1])
                # Inclusive range: "2-4" selects 2, 3 and 4.
                opcode_list.extend(range(start, end + 1))
            else:
                print("Error: Cannot get operator list")
                print(
                    "Please pass operators as operator index or range list split by space and/or line"
                )
                sys.exit(1)

    if len(opcode_list) == 0:
        print("No selected operator")
        sys.exit(1)

    return opcode_list
def GetUsedSubgraphsList(sample_model, subg_num, operator_list, used_subgraphs_list):
    """Recursively collect subgraph indices referenced by selected operators.

    Only IF operators (then/else subgraphs) and WHILE operators (cond/body
    subgraphs) are inspected. Every newly discovered subgraph index is
    appended to ``used_subgraphs_list`` and that subgraph is then scanned in
    full, so references nested at any depth are followed. An index already
    present in ``used_subgraphs_list`` is never rescanned, so each subgraph
    is visited at most once and the recursion always terminates.

    Args:
        sample_model: The source ``tflite.Model.Model``.
        subg_num: Index of the subgraph that owns ``operator_list``.
        operator_list: Iterable of operator indices inside ``subg_num``.
        used_subgraphs_list: List extended in place with discovered
            subgraph indices.
    """
    import tflite.IfOptions
    import tflite.WhileOptions

    subg_list = []

    selected_subgraph = sample_model.Subgraphs(subg_num)
    builtin_options = tflite.BuiltinOptions.BuiltinOptions()

    for operator_idx in operator_list:
        selected_operator = selected_subgraph.Operators(operator_idx)
        options_type = selected_operator.BuiltinOptionsType()

        if options_type == builtin_options.IfOptions:
            selected_builtin_option = selected_operator.BuiltinOptions()
            if_option = tflite.IfOptions.IfOptions()
            if_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)

            subg_list.append(if_option.ElseSubgraphIndex())
            subg_list.append(if_option.ThenSubgraphIndex())

        if options_type == builtin_options.WhileOptions:
            selected_builtin_option = selected_operator.BuiltinOptions()
            while_option = tflite.WhileOptions.WhileOptions()
            while_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)

            subg_list.append(while_option.BodySubgraphIndex())
            subg_list.append(while_option.CondSubgraphIndex())

    for idx in subg_list:
        if idx not in used_subgraphs_list:
            used_subgraphs_list.append(idx)
            # Scan every operator of the referenced subgraph, i.e. indices
            # 0 .. OperatorsLength() - 1 inclusive; an IF/WHILE that is the
            # last operator of a branch/body must be followed too.
            GetUsedSubgraphsList(sample_model, idx,
                                 range(sample_model.Subgraphs(idx).OperatorsLength()),
                                 used_subgraphs_list)
96 def GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
98 operator_code_num = sample_model.OperatorCodesLength()
99 new_operator_code_list = []
100 new_operator_code_string_list = {}
102 if operator_code_num == 0:
105 # Create operator_code string
106 for operator_code_idx in range(operator_code_num):
107 if operator_code_idx in used_opcodes_dic:
108 operator_code = sample_model.OperatorCodes(operator_code_idx)
109 operator_code_string = operator_code.CustomCode()
110 if operator_code_string and (operator_code_string != "") and (
111 not operator_code_string in new_operator_code_string_list):
112 new_operator_code_string_list[
113 operator_code_string] = new_builder.CreateString(operator_code_string)
115 # Create tables of operator_code
116 for operator_code_idx in range(operator_code_num):
117 if operator_code_idx in used_opcodes_dic:
118 operator_code = sample_model.OperatorCodes(operator_code_idx)
120 # Create operator_code table
121 tflite.OperatorCode.OperatorCodeStart(new_builder)
122 tflite.OperatorCode.OperatorCodeAddBuiltinCode(new_builder,
123 operator_code.BuiltinCode())
125 new_operator_code_string = operator_code.CustomCode()
126 if new_operator_code_string in new_operator_code_string_list:
127 tflite.OperatorCode.OperatorCodeAddCustomCode(
128 new_builder, new_operator_code_string_list[new_operator_code_string])
129 new_operator_code = tflite.OperatorCode.OperatorCodeEnd(new_builder)
130 new_operator_code_list.append(new_operator_code)
132 # Create operator_code vector
133 new_operator_code_num = len(new_operator_code_list)
134 tflite.Model.ModelStartOperatorCodesVector(new_builder, new_operator_code_num)
135 for operator_code_idx in reversed(range(new_operator_code_num)):
136 new_builder.PrependUOffsetTRelative(new_operator_code_list[operator_code_idx])
138 return new_builder.EndVector(new_operator_code_num)
def GenerateQuantization(new_builder, selected_quantization):
    """Copy a tensor's quantization parameters into ``new_builder``.

    The min, max, scale and zero_point vectors are copied element by
    element. A vector that is empty in the source is left out of the new
    table instead of being written as an empty vector.

    Args:
        new_builder: Flatbuffers builder that receives the new model.
        selected_quantization: Source ``tflite.QuantizationParameters``.

    Returns:
        Offset of the new ``QuantizationParameters`` table in ``new_builder``.
    """
    quant_params = tflite.QuantizationParameters

    def _copy_vector(start_vector, length, prepend, getter):
        # Flatbuffers vectors are written back to front, so elements are
        # prepended from the last index down to 0.
        start_vector(new_builder, length)
        for idx in reversed(range(length)):
            prepend(getter(idx))
        return new_builder.EndVector(length)

    # All vectors are finished before the table is started below; the
    # flatbuffers builder does not allow building them inside the table.

    # Create min vector
    min_num = selected_quantization.MinLength()
    if min_num != 0:
        new_min = _copy_vector(quant_params.QuantizationParametersStartMinVector,
                               min_num, new_builder.PrependFloat32,
                               selected_quantization.Min)

    # Create max vector
    max_num = selected_quantization.MaxLength()
    if max_num != 0:
        new_max = _copy_vector(quant_params.QuantizationParametersStartMaxVector,
                               max_num, new_builder.PrependFloat32,
                               selected_quantization.Max)

    # Create scale vector
    scale_num = selected_quantization.ScaleLength()
    if scale_num != 0:
        new_scale = _copy_vector(quant_params.QuantizationParametersStartScaleVector,
                                 scale_num, new_builder.PrependFloat32,
                                 selected_quantization.Scale)

    # Create zero_point vector
    zeropoint_num = selected_quantization.ZeroPointLength()
    if zeropoint_num != 0:
        new_zeropoint = _copy_vector(
            quant_params.QuantizationParametersStartZeroPointVector, zeropoint_num,
            new_builder.PrependInt64, selected_quantization.ZeroPoint)

    # NOTE(review): only min/max/scale/zero_point are copied. If the schema
    # in use also has quantized_dimension/details, per-channel quantized
    # tensors may need them copied as well -- verify against the schema.

    # Create quantization
    quant_params.QuantizationParametersStart(new_builder)
    if min_num != 0:
        quant_params.QuantizationParametersAddMin(new_builder, new_min)
    if max_num != 0:
        quant_params.QuantizationParametersAddMax(new_builder, new_max)
    if scale_num != 0:
        quant_params.QuantizationParametersAddScale(new_builder, new_scale)
    if zeropoint_num != 0:
        quant_params.QuantizationParametersAddZeroPoint(new_builder, new_zeropoint)

    return quant_params.QuantizationParametersEnd(new_builder)
def GenerateTensor(new_builder, selected_tensor, used_buffers_dic):
    """Copy one tensor table into ``new_builder``.

    Shape, type, name and quantization are copied as-is. The buffer index
    is translated through ``used_buffers_dic``.

    Args:
        new_builder: Flatbuffers builder that receives the new model.
        selected_tensor: Source ``tflite.Tensor``.
        used_buffers_dic: Mapping from a source buffer index to the buffer
            index to use in the new model (must contain the tensor's buffer).

    Returns:
        Offset of the new ``Tensor`` table in ``new_builder``.

    Raises:
        KeyError: If the tensor's buffer index is not in ``used_buffers_dic``.
    """
    # Create shape vector for tensor (flatbuffers vectors are built back
    # to front, hence the reversed index order).
    shape_num = selected_tensor.ShapeLength()
    tflite.Tensor.TensorStartShapeVector(new_builder, shape_num)
    for shape_idx in reversed(range(shape_num)):
        new_builder.PrependInt32(selected_tensor.Shape(shape_idx))
    new_shape = new_builder.EndVector(shape_num)

    # Tensor type is a scalar and is copied directly.
    tensor_type = selected_tensor.Type()

    # Map the source buffer index to its index in the new model.
    buffer_idx = selected_tensor.Buffer()
    new_buffer_idx = used_buffers_dic[buffer_idx]

    # Create name string. A truthiness test is used because Name() may
    # return None for an unnamed tensor (and presumably bytes rather than
    # str -- TODO confirm); the old `!= ""` comparison was true for both,
    # which sent None into CreateString.
    name_string = selected_tensor.Name()
    has_name = bool(name_string)
    if has_name:
        new_name = new_builder.CreateString(name_string)

    # Create quantization
    quantization = selected_tensor.Quantization()
    if quantization is not None:
        new_quantization = GenerateQuantization(new_builder, quantization)

    # Create tensor (all child objects above must exist before TensorStart).
    tflite.Tensor.TensorStart(new_builder)
    tflite.Tensor.TensorAddShape(new_builder, new_shape)
    tflite.Tensor.TensorAddType(new_builder, tensor_type)
    tflite.Tensor.TensorAddBuffer(new_builder, new_buffer_idx)
    if has_name:
        tflite.Tensor.TensorAddName(new_builder, new_name)
    if quantization is not None:
        tflite.Tensor.TensorAddQuantization(new_builder, new_quantization)

    return tflite.Tensor.TensorEnd(new_builder)
def GenerateTensors(new_builder, selected_subgraph, used_tensors_dic, used_buffers_dic):
    """Copy the selected tensors of a subgraph into a new tensors vector.

    Tensors are visited in their original index order and only those whose
    index is a key/member of ``used_tensors_dic`` are copied (each via
    ``GenerateTensor``).

    Args:
        new_builder: Flatbuffers builder that receives the new model.
        selected_subgraph: Source ``tflite.SubGraph``.
        used_tensors_dic: Container of source tensor indices to keep.
        used_buffers_dic: Buffer-index mapping forwarded to ``GenerateTensor``.

    Returns:
        Offset of the new tensors vector, or 0 when no tensor is selected
        (no vector is written then; presumably the caller treats 0 as
        "no tensors" -- verify).
    """
    tensor_num = selected_subgraph.TensorsLength()

    # Every tensor table must be built before the vector is started.
    new_tensor_list = [
        GenerateTensor(new_builder, selected_subgraph.Tensors(tensor_idx),
                       used_buffers_dic)
        for tensor_idx in range(tensor_num)
        if tensor_idx in used_tensors_dic
    ]

    new_tensor_num = len(new_tensor_list)
    if new_tensor_num == 0:
        return 0

    # Flatbuffers vectors are written back to front, so prepend in reverse
    # to keep the original order.
    tflite.SubGraph.SubGraphStartTensorsVector(new_builder, new_tensor_num)
    for new_tensor in reversed(new_tensor_list):
        new_builder.PrependUOffsetTRelative(new_tensor)

    return new_builder.EndVector(new_tensor_num)
258 def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type,
262 import tflite.Conv2DOptions
263 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Conv2DOptions:
265 conv2d_options = tflite.Conv2DOptions.Conv2DOptions()
266 conv2d_options.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
268 tflite.Conv2DOptions.Conv2DOptionsStart(new_builder)
269 tflite.Conv2DOptions.Conv2DOptionsAddPadding(new_builder,
270 conv2d_options.Padding())
271 tflite.Conv2DOptions.Conv2DOptionsAddStrideW(new_builder,
272 conv2d_options.StrideW())
273 tflite.Conv2DOptions.Conv2DOptionsAddStrideH(new_builder,
274 conv2d_options.StrideH())
275 tflite.Conv2DOptions.Conv2DOptionsAddDilationWFactor(
276 new_builder, conv2d_options.DilationWFactor())
277 tflite.Conv2DOptions.Conv2DOptionsAddDilationHFactor(
278 new_builder, conv2d_options.DilationHFactor())
279 tflite.Conv2DOptions.Conv2DOptionsAddFusedActivationFunction(
280 new_builder, conv2d_options.FusedActivationFunction())
281 return tflite.Conv2DOptions.Conv2DOptionsEnd(new_builder)
283 # DepthwiseConv2D option
284 import tflite.DepthwiseConv2DOptions
285 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
286 ).DepthwiseConv2DOptions:
288 depthconv2d_option = tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions()
289 depthconv2d_option.Init(selected_builtin_option.Bytes,
290 selected_builtin_option.Pos)
292 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsStart(new_builder)
293 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddPadding(
294 new_builder, depthconv2d_option.Padding())
295 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideW(
296 new_builder, depthconv2d_option.StrideW())
297 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideH(
298 new_builder, depthconv2d_option.StrideH())
299 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDepthMultiplier(
300 new_builder, depthconv2d_option.DepthMultiplier())
301 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddFusedActivationFunction(
302 new_builder, depthconv2d_option.FusedActivationFunction())
303 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationWFactor(
304 new_builder, depthconv2d_option.DilationWFactor())
305 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationHFactor(
306 new_builder, depthconv2d_option.DilationHFactor())
307 return tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsEnd(new_builder)
309 # ConcatEmbeddingsOptions: not supported
310 # LSHProjectionOptions: not supported
313 import tflite.Pool2DOptions
314 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Pool2DOptions:
316 pool2d_option = tflite.Pool2DOptions.Pool2DOptions()
317 pool2d_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
319 tflite.Pool2DOptions.Pool2DOptionsStart(new_builder)
320 tflite.Pool2DOptions.Pool2DOptionsAddPadding(new_builder, pool2d_option.Padding())
321 tflite.Pool2DOptions.Pool2DOptionsAddStrideW(new_builder, pool2d_option.StrideW())
322 tflite.Pool2DOptions.Pool2DOptionsAddStrideH(new_builder, pool2d_option.StrideH())
323 tflite.Pool2DOptions.Pool2DOptionsAddFilterWidth(new_builder,
324 pool2d_option.FilterWidth())
325 tflite.Pool2DOptions.Pool2DOptionsAddFilterHeight(new_builder,
326 pool2d_option.FilterHeight())
327 tflite.Pool2DOptions.Pool2DOptionsAddFusedActivationFunction(
328 new_builder, pool2d_option.FusedActivationFunction())
329 return tflite.Pool2DOptions.Pool2DOptionsEnd(new_builder)
331 # SVDFOptions: not supported
334 import tflite.RNNOptions
335 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().RNNOptions:
337 rnn_option = tflite.RNNOptions.RNNOptions()
338 rnn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
340 tflite.RNNOptions.RNNOptionsStart(new_builder)
341 tflite.RNNOptions.RNNOptionsAddFusedActivationFunction(
342 new_builder, rnn_option.FusedActivationFunction())
343 return tflite.RNNOptions.RNNOptionsEnd(new_builder)
345 # FullyConnectedOptions
346 import tflite.FullyConnectedOptions
347 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
348 ).FullyConnectedOptions:
350 fc_option = tflite.FullyConnectedOptions.FullyConnectedOptions()
351 fc_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
353 tflite.FullyConnectedOptions.FullyConnectedOptionsStart(new_builder)
354 tflite.FullyConnectedOptions.FullyConnectedOptionsAddFusedActivationFunction(
355 new_builder, fc_option.FusedActivationFunction())
356 return tflite.FullyConnectedOptions.FullyConnectedOptionsEnd(new_builder)
359 import tflite.SoftmaxOptions
360 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SoftmaxOptions:
362 softmax_option = tflite.SoftmaxOptions.SoftmaxOptions()
363 softmax_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
365 tflite.SoftmaxOptions.SoftmaxOptionsStart(new_builder)
366 tflite.SoftmaxOptions.SoftmaxOptionsAddBeta(new_builder, softmax_option.Beta())
367 return tflite.SoftmaxOptions.SoftmaxOptionsEnd(new_builder)
369 # ConcatenationOptions
370 import tflite.ConcatenationOptions
371 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ConcatenationOptions:
373 concat_option = tflite.ConcatenationOptions.ConcatenationOptions()
374 concat_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
376 tflite.ConcatenationOptions.ConcatenationOptionsStart(new_builder)
377 tflite.ConcatenationOptions.ConcatenationOptionsAddAxis(
378 new_builder, concat_option.Axis())
379 tflite.ConcatenationOptions.ConcatenationOptionsAddFusedActivationFunction(
380 new_builder, concat_option.FusedActivationFunction())
381 return tflite.ConcatenationOptions.ConcatenationOptionsEnd(new_builder)
384 import tflite.AddOptions
385 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions:
387 add_option = tflite.AddOptions.AddOptions()
388 add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
390 tflite.AddOptions.AddOptionsStart(new_builder)
391 tflite.AddOptions.AddOptionsAddFusedActivationFunction(
392 new_builder, add_option.FusedActivationFunction())
393 return tflite.AddOptions.AddOptionsEnd(new_builder)
396 import tflite.L2NormOptions
397 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().L2NormOptions:
399 l2norm_option = tflite.L2NormOptions.L2NormOptions()
400 l2norm_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
402 tflite.L2NormOptions.L2NormOptionsStart(new_builder)
403 tflite.L2NormOptions.L2NormOptionsAddFusedActivationFunction(
404 new_builder, l2norm_option.FusedActivationFunction())
405 return tflite.L2NormOptions.L2NormOptionsEnd(new_builder)
407 # LocalResponseNormalizationOptions
408 import tflite.LocalResponseNormalizationOptions
409 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
410 ).LocalResponseNormalizationOptions:
412 lrn_option = tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptions(
414 lrn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
416 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsStart(
418 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddRadius(
419 new_builder, lrn_option.Radius())
420 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBias(
421 new_builder, lrn_option.Bias())
422 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddAlpha(
423 new_builder, lrn_option.Alpha())
424 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBeta(
425 new_builder, lrn_option.Beta())
426 return tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsEnd(
429 # LSTMOptions: not supported
431 # ResizeBilinearOptions
432 import tflite.ResizeBilinearOptions
433 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
434 ).ResizeBilinearOptions:
436 resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions()
437 resize_bilinear_option.Init(selected_builtin_option.Bytes,
438 selected_builtin_option.Pos)
440 tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder)
441 tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners(
442 new_builder, resize_bilinear_option.AlignCorners())
443 return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder)
445 # CallOptions: not supported
448 import tflite.ReshapeOptions
449 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReshapeOptions:
451 reshape_option = tflite.ReshapeOptions.ReshapeOptions()
452 reshape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
454 shape_num = reshape_option.NewShapeLength()
456 tflite.ReshapeOptions.ReshapeOptionsStartNewShapeVector(
457 new_builder, shape_num)
458 for new_shape_idx in reversed(range(shape_num)):
459 new_shape_val = reshape_option.NewShape(new_shape_idx)
460 new_builder.PrependInt32(new_shape_val)
461 new_shape = new_builder.EndVector(shape_num)
463 tflite.ReshapeOptions.ReshapeOptionsStart(new_builder)
465 tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape)
466 return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder)
468 # SkipGramOptions: not supported
470 # SpaceToDepthOptions
471 import tflite.SpaceToDepthOptions
472 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SpaceToDepthOptions:
474 space_to_depth_option = tflite.SpaceToDepthOptions.SpaceToDepthOptions()
475 space_to_depth_option.Init(selected_builtin_option.Bytes,
476 selected_builtin_option.Pos)
478 tflite.SpaceToDepthOptions.SpaceToDepthOptionsStart(new_builder)
479 tflite.SpaceToDepthOptions.SpaceToDepthOptionsAddBlockSize(
480 new_builder, space_to_depth_option.BlockSize())
481 return tflite.SpaceToDepthOptions.SpaceToDepthOptionsEnd(new_builder)
483 # EmbeddingLookupSparseOptions: not supported
486 import tflite.MulOptions
487 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().MulOptions:
489 mul_option = tflite.MulOptions.MulOptions()
490 mul_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
492 tflite.MulOptions.MulOptionsStart(new_builder)
493 tflite.MulOptions.MulOptionsAddFusedActivationFunction(
494 new_builder, mul_option.FusedActivationFunction())
495 return tflite.MulOptions.MulOptionsEnd(new_builder)
498 import tflite.PadOptions
499 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PadOptions:
501 pad_option = tflite.PadOptions.PadOptions()
502 pad_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
504 tflite.PadOptions.PadOptionsStart(new_builder)
505 return tflite.PadOptions.PadOptionsEnd(new_builder)
508 import tflite.GatherOptions
509 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().GatherOptions:
511 gather_option = tflite.GatherOptions.GatherOptions()
512 gather_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
514 tflite.GatherOptions.GatherOptionsStart(new_builder)
515 tflite.GatherOptions.GatherOptionsAddAxis(new_builder, gather_option.Axis())
516 return tflite.GatherOptions.GatherOptionsEnd(new_builder)
518 # BatchToSpaceNDOptions
519 import tflite.BatchToSpaceNDOptions
520 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
521 ).BatchToSpaceNDOptions:
523 btsnd_option = tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptions()
524 btsnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
526 tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsStart(new_builder)
527 return tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsEnd(new_builder)
529 # SpaceToBatchNDOptions
530 import tflite.SpaceToBatchNDOptions
531 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
532 ).SpaceToBatchNDOptions:
534 stbnd_option = tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptions()
535 stbnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
537 tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsStart(new_builder)
538 return tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsEnd(new_builder)
541 import tflite.TransposeOptions
542 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeOptions:
544 transpose_option = tflite.TransposeOptions.TransposeOptions()
545 transpose_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
547 tflite.TransposeOptions.TransposeOptionsStart(new_builder)
548 return tflite.TransposeOptions.TransposeOptionsEnd(new_builder)
551 import tflite.ReducerOptions
552 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReducerOptions:
554 reducer_option = tflite.ReducerOptions.ReducerOptions()
555 reducer_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
557 tflite.ReducerOptions.ReducerOptionsStart(new_builder)
558 tflite.ReducerOptions.ReducerOptionsAddKeepDims(new_builder,
559 reducer_option.KeepDims())
560 return tflite.ReducerOptions.ReducerOptionsEnd(new_builder)
563 import tflite.SubOptions
564 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions:
566 sub_option = tflite.SubOptions.SubOptions()
567 sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
569 tflite.SubOptions.SubOptionsStart(new_builder)
570 tflite.SubOptions.SubOptionsAddFusedActivationFunction(
571 new_builder, sub_option.FusedActivationFunction())
572 return tflite.SubOptions.SubOptionsEnd(new_builder)
575 import tflite.DivOptions
576 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DivOptions:
578 div_option = tflite.DivOptions.DivOptions()
579 div_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
581 tflite.DivOptions.DivOptionsStart(new_builder)
582 tflite.DivOptions.DivOptionsAddFusedActivationFunction(
583 new_builder, div_option.FusedActivationFunction())
584 return tflite.DivOptions.DivOptionsEnd(new_builder)
587 import tflite.SqueezeOptions
588 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SqueezeOptions:
590 squeeze_option = tflite.SqueezeOptions.SqueezeOptions()
591 squeeze_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
593 squeeze_dims_num = squeeze_option.SqueezeDimsLength()
594 if squeeze_dims_num != 0:
595 tflite.SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector(
596 new_builder, squeeze_dims_num)
597 for squeeze_dims_idx in reversed(range(squeeze_dims_num)):
598 squeeze_dims_val = squeeze_option.SqueezeDims(squeeze_dims_idx)
599 new_builder.PrependInt32(squeeze_dims_val)
600 new_squeeze_dims = new_builder.EndVector(squeeze_dims_num)
602 tflite.SqueezeOptions.SqueezeOptionsStart(new_builder)
603 if squeeze_dims_num != 0:
604 tflite.SqueezeOptions.SqueezeOptionsAddSqueezeDims(new_builder,
606 return tflite.SqueezeOptions.SqueezeOptionsEnd(new_builder)
608 # SequenceRNNOptions: not supported
610 # StridedSliceOptions
611 import tflite.StridedSliceOptions
612 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().StridedSliceOptions:
614 stride_slice_option = tflite.StridedSliceOptions.StridedSliceOptions()
615 stride_slice_option.Init(selected_builtin_option.Bytes,
616 selected_builtin_option.Pos)
618 tflite.StridedSliceOptions.StridedSliceOptionsStart(new_builder)
619 tflite.StridedSliceOptions.StridedSliceOptionsAddBeginMask(
620 new_builder, stride_slice_option.BeginMask())
621 tflite.StridedSliceOptions.StridedSliceOptionsAddEndMask(
622 new_builder, stride_slice_option.EndMask())
623 tflite.StridedSliceOptions.StridedSliceOptionsAddEllipsisMask(
624 new_builder, stride_slice_option.EllipsisMask())
625 tflite.StridedSliceOptions.StridedSliceOptionsAddNewAxisMask(
626 new_builder, stride_slice_option.NewAxisMask())
627 tflite.StridedSliceOptions.StridedSliceOptionsAddShrinkAxisMask(
628 new_builder, stride_slice_option.ShrinkAxisMask())
630 return tflite.StridedSliceOptions.StridedSliceOptionsEnd(new_builder)
633 import tflite.ExpOptions
634 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpOptions:
636 exp_option = tflite.ExpOptions.ExpOptions()
637 exp_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
639 tflite.ExpOptions.ExpOptionsStart(new_builder)
640 return tflite.ExpOptions.ExpOptionsEnd(new_builder)
643 import tflite.TopKV2Options
644 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TopKV2Options:
646 topkv2_option = tflite.TopKV2Options.TopKV2Options()
647 topkv2_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
649 tflite.TopKV2Options.TopKV2OptionsStart(new_builder)
650 return tflite.TopKV2Options.TopKV2OptionsEnd(new_builder)
653 import tflite.SplitOptions
654 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SplitOptions:
656 split_option = tflite.SplitOptions.SplitOptions()
657 split_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
659 tflite.SplitOptions.SplitOptionsStart(new_builder)
660 tflite.SplitOptions.SplitOptionsAddNumSplits(new_builder,
661 split_option.NumSplits())
662 return tflite.SplitOptions.SplitOptionsEnd(new_builder)
664 # LogSoftmaxOptions: not supported
666 # CastOptions: not supported
667 import tflite.CastOptions
668 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions:
670 cast_option = tflite.CastOptions.CastOptions()
671 cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
673 tflite.CastOptions.CastOptionsStart(new_builder)
674 return tflite.CastOptions.CastOptionsEnd(new_builder)
677 import tflite.DequantizeOptions
678 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DequantizeOptions:
680 dequantize_option = tflite.DequantizeOptions.DequantizeOptions()
681 dequantize_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
683 tflite.EqualOptions.DequantizeOptionsStart(new_builder)
684 return tflite.DequantizeOptions.DequantizeOptionsEnd(new_builder)
686 # MaximumMinimumOptions: not supported
689 import tflite.ArgMaxOptions
690 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ArgMaxOptions:
692 arg_max_option = tflite.ArgMaxOptions.ArgMaxOptions()
693 arg_max_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
695 tflite.ArgMaxOptions.ArgMaxOptionsStart(new_builder)
696 tflite.ArgMaxOptions.ArgMaxOptionsAddOutputType(new_builder,
697 arg_max_option.OutputType())
698 return tflite.ArgMaxOptions.ArgMaxOptionsEnd(new_builder)
700 # LessOptions: not supported
703 import tflite.NegOptions
704 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NegOptions:
706 neg_option = tflite.NegOptions.NegOptions()
707 neg_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
709 tflite.NegOptions.NegOptionsStart(new_builder)
710 return tflite.NegOptions.NegOptionsEnd(new_builder)
713 import tflite.EqualOptions
714 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().EqualOptions:
716 equal_option = tflite.EqualOptions.EqualOptions()
717 equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
719 tflite.EqualOptions.EqualOptionsStart(new_builder)
720 return tflite.EqualOptions.EqualOptionsEnd(new_builder)
722 # PadV2Options: not supported
723 # GreaterOptions: not supported
724 # GreaterEqualOptions: not supported
725 # LessEqualOptions: not supported
728 import tflite.SelectOptions
729 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SelectOptions:
731 select_option = tflite.SelectOptions.SelectOptions()
732 select_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
734 tflite.SelectOptions.SelectOptionsStart(new_builder)
735 return tflite.SelectOptions.SelectOptionsEnd(new_builder)
737 # SliceOptions: not supported
739 # TransposeConvOptions
740 import tflite.TransposeConvOptions
741 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeConvOptions:
743 transposeconv_option = tflite.TransposeConvOptions.TransposeConvOptions()
744 transposeconv_option.Init(selected_builtin_option.Bytes,
745 selected_builtin_option.Pos)
747 tflite.TransposeConvOptions.TransposeConvOptionsStart(new_builder)
748 tflite.TransposeConvOptions.TransposeConvOptionsAddPadding(
749 new_builder, transposeconv_option.Padding())
750 tflite.TransposeConvOptions.TransposeConvOptionsAddStrideW(
751 new_builder, transposeconv_option.StrideW())
752 tflite.TransposeConvOptions.TransposeConvOptionsAddStrideH(
753 new_builder, transposeconv_option.StrideH())
754 return tflite.TransposeConvOptions.TransposeConvOptionsEnd(new_builder)
756 # SparseToDenseOptions: not supported
757 # TileOptions: not supported
760 import tflite.ExpandDimsOptions
761 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpandDimsOptions:
763 expanddims_option = tflite.ExpandDimsOptions.ExpandDimsOptions()
764 expanddims_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
766 tflite.ExpandDimsOptions.ExpandDimsOptionsStart(new_builder)
767 return tflite.ExpandDimsOptions.ExpandDimsOptionsEnd(new_builder)
770 import tflite.NotEqualOptions
771 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NotEqualOptions:
773 notequal_option = tflite.NotEqualOptions.NotEqualOptions()
774 notequal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
776 tflite.NotEqualOptions.NotEqualOptionsStart(new_builder)
777 return tflite.NotEqualOptions.NotEqualOptionsEnd(new_builder)
780 import tflite.ShapeOptions
781 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ShapeOptions:
783 shape_option = tflite.ShapeOptions.ShapeOptions()
784 shape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
786 tflite.ShapeOptions.ShapeOptionsStart(new_builder)
787 tflite.ShapeOptions.ShapeOptionsAddOutType(new_builder, shape_option.OutType())
788 return tflite.ShapeOptions.ShapeOptionsEnd(new_builder)
790 # PowOptions: not supported
791 # ArgMinOptions: not supported
792 # FakeQuantOptions: not supported
795 import tflite.PackOptions
796 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PackOptions:
798 pack_option = tflite.PackOptions.PackOptions()
799 pack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
801 tflite.PackOptions.PackOptionsStart(new_builder)
802 tflite.PackOptions.PackOptionsAddValuesCount(new_builder,
803 pack_option.ValuesCount())
804 tflite.PackOptions.PackOptionsAddAxis(new_builder, pack_option.Axis())
805 return tflite.PackOptions.PackOptionsEnd(new_builder)
808 import tflite.LogicalOrOptions
809 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalOrOptions:
811 logical_or_option = tflite.LogicalAndOptions.LogicalOrOptions()
812 logical_or_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
814 tflite.LogicalOrOptions.LogicalOrOptionsStart(new_builder)
815 return tflite.LogicalOrOptions.LogicalOrOptionsEnd(new_builder)
817 # OneHotOptions: not supported
818 import tflite.OneHotOptions
819 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().OneHotOptions:
821 one_hot_option = tflite.OneHotOptions.OneHotOptions()
822 one_hot_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
824 tflite.OneHotOptions.OneHotOptionsStart(new_builder)
825 tflite.OneHotOptions.OneHotOptionsAddAxis(new_builder, one_hot_option.Axis())
826 return tflite.OneHotOptions.OneHotOptionsEnd(new_builder)
829 import tflite.LogicalNotOptions
830 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalNotOptions:
832 equal_option = tflite.LogicalNotOptions.LogicalNotOptions()
833 equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
835 tflite.LogicalNotOptions.LogicalNotOptionsStart(new_builder)
836 return tflite.LogicalNotOptions.LogicalNotOptionsEnd(new_builder)
839 import tflite.UnpackOptions
840 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().UnpackOptions:
842 unpack_option = tflite.UnpackOptions.UnpackOptions()
843 unpack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
845 tflite.UnpackOptions.UnpackOptionsStart(new_builder)
846 tflite.UnpackOptions.UnpackOptionsAddNum(new_builder, unpack_option.Num())
847 tflite.UnpackOptions.UnpackOptionsAddAxis(new_builder, unpack_option.Axis())
848 return tflite.UnpackOptions.UnpackOptionsEnd(new_builder)
850 # FloorDivOptions: not supported
851 # SquareOptions: not supported
852 # ZerosLikeOptions: not supported
853 # FillOptions: not supported
856 import tflite.LogicalAndOptions
857 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalAndOptions:
859 logical_and_option = tflite.LogicalAndOptions.LogicalAndOptions()
860 logical_and_option.Init(selected_builtin_option.Bytes,
861 selected_builtin_option.Pos)
863 tflite.LogicalAndOptions.LogicalAndOptionsStart(new_builder)
864 return tflite.LogicalAndOptions.LogicalAndOptionsEnd(new_builder)
866 # LogicalNotOptions: not supported
867 # UnpackOptions: not supported
868 # FloorDivOptions: not supported
869 # SquareOptions: not supported
870 # ZerosLikeOptions: not supported
871 # FillOptions: not supported
872 # BidirectionalSequenceLSTMOptions: not supported
873 # BidirectionalSequenceRNNOptions: not supported
874 # UnidirectionalSequenceLSTMOptions: not supported
875 # FloorModOptions: not supported
876 # RangeOptions: not supported
877 # ResizeNearestNeighborOptions: not supported
880 import tflite.LeakyReluOptions
881 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LeakyReluOptions:
883 leaky_relu_option = tflite.LeakyReluOptions.LeakyReluOptions()
884 leaky_relu_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
886 tflite.LeakyReluOptions.LeakyReluOptionsStart(new_builder)
887 tflite.LeakyReluOptions.LeakyReluOptionsAddAlpha(new_builder,
888 leaky_relu_option.Alpha())
889 return tflite.LeakyReluOptions.LeakyReluOptionsEnd(new_builder)
891 # SquaredDifferenceOptions
892 import tflite.SquaredDifferenceOptions
893 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
894 ).SquaredDifferenceOptions:
896 squared_difference_option = tflite.SquaredDifferenceOptions.SquaredDifferenceOptions(
898 squared_difference_option.Init(selected_builtin_option.Bytes,
899 selected_builtin_option.Pos)
901 tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsStart(new_builder)
902 return tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsEnd(new_builder)
904 # MirrorPadOptions: not supported
905 # AbsOptions: not supported
906 # SplitVOptions: not supported
909 import tflite.IfOptions
910 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().IfOptions:
912 if_option = tflite.IfOptions.IfOptions()
913 if_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
915 tflite.IfOptions.IfOptionsStart(new_builder)
916 tflite.IfOptions.IfOptionsAddElseSubgraphIndex(
917 new_builder, used_subgraphs_dic[if_option.ElseSubgraphIndex()])
918 tflite.IfOptions.IfOptionsAddThenSubgraphIndex(
919 new_builder, used_subgraphs_dic[if_option.ThenSubgraphIndex()])
920 return tflite.IfOptions.IfOptionsEnd(new_builder)
923 import tflite.WhileOptions
924 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().WhileOptions:
926 while_option = tflite.WhileOptions.WhileOptions()
927 while_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
929 tflite.WhileOptions.WhileOptionsStart(new_builder)
930 tflite.WhileOptions.WhileOptionsAddBodySubgraphIndex(
931 new_builder, used_subgraphs_dic[while_option.BodySubgraphIndex()])
932 tflite.WhileOptions.WhileOptionsAddCondSubgraphIndex(
933 new_builder, used_subgraphs_dic[while_option.CondSubgraphIndex()])
934 return tflite.WhileOptions.WhileOptionsEnd(new_builder)
936 # Cannot handle builtin option type yet
937 print("Cannot handle BuiltinOptions {} yet. See BuiltinOptions.py for op name".format(
938 builtin_option_type))
def GenerateOperator(new_builder, selected_operator, used_tensors_dic, used_opcodes_dic,
                     used_subgraphs_dic):
    """Serialize one source-model operator into ``new_builder``.

    Opcode, tensor and subgraph indices are translated to the numbering of the
    model being generated. Every child object (input/output index vectors,
    builtin options, custom options) is created before ``OperatorStart`` because
    a flatbuffers table cannot be built while another object is in progress.

    Args:
        new_builder: flatbuffers ``Builder`` the new model is written into.
        selected_operator: ``tflite.Operator`` read from the source model.
        used_tensors_dic: source tensor index -> new tensor index.
        used_opcodes_dic: source opcode index -> new opcode index.
        used_subgraphs_dic: source subgraph index -> new subgraph index; passed
            through to ``GenerateBuiltinOption`` (If/While options reference
            subgraphs by index).

    Returns:
        Offset of the finished ``Operator`` table.

    Raises:
        KeyError: an opcode or tensor of the operator is missing from the maps.
        ValueError: the operator's builtin options could not be serialized.
    """
    # Opcode index into the new model's operator_codes vector.
    new_opcode_index = used_opcodes_dic[selected_operator.OpcodeIndex()]

    # Inputs vector. Flatbuffers vectors are filled back-to-front, hence reversed().
    # -1 is not a tensor index (no tensor connected) and is copied through as-is.
    input_num = selected_operator.InputsLength()
    if input_num != 0:
        tflite.Operator.OperatorStartInputsVector(new_builder, input_num)
        for input_idx in reversed(range(input_num)):
            input_tensor_idx = selected_operator.Inputs(input_idx)
            if input_tensor_idx == -1:
                new_input_tensor_idx = -1
            else:
                new_input_tensor_idx = used_tensors_dic[input_tensor_idx]
            new_builder.PrependInt32(new_input_tensor_idx)
        new_input = new_builder.EndVector(input_num)

    # Outputs vector, remapped the same way.
    output_num = selected_operator.OutputsLength()
    if output_num != 0:
        tflite.Operator.OperatorStartOutputsVector(new_builder, output_num)
        for output_idx in reversed(range(output_num)):
            output_tensor_idx = selected_operator.Outputs(output_idx)
            new_builder.PrependInt32(used_tensors_dic[output_tensor_idx])
        new_output = new_builder.EndVector(output_num)

    # Builtin options; a type of 0 means the operator carries none.
    builtin_option_type = selected_operator.BuiltinOptionsType()
    if builtin_option_type != 0:
        selected_builtin_option = selected_operator.BuiltinOptions()
        new_builtin_option = GenerateBuiltinOption(
            new_builder, selected_builtin_option, builtin_option_type, used_subgraphs_dic)
        if new_builtin_option is None:
            # GenerateBuiltinOption ends by printing "Cannot handle BuiltinOptions"
            # for types it does not support; if it returns from there, handing None
            # to OperatorAddBuiltinOptions would only fail later inside flatbuffers
            # with an unrelated-looking TypeError.
            raise ValueError(
                "Cannot serialize BuiltinOptions type {}".format(builtin_option_type))

    # Custom options: raw bytes copied one by one (again back-to-front).
    custom_option_num = selected_operator.CustomOptionsLength()
    if custom_option_num != 0:
        tflite.Operator.OperatorStartCustomOptionsVector(new_builder, custom_option_num)
        for custom_option_idx in reversed(range(custom_option_num)):
            new_builder.PrependUint8(selected_operator.CustomOptions(custom_option_idx))
        new_custom_option = new_builder.EndVector(custom_option_num)

    # Custom options format is copied unchanged.
    custom_option_type = selected_operator.CustomOptionsFormat()

    # All children exist; assemble the Operator table itself.
    tflite.Operator.OperatorStart(new_builder)
    tflite.Operator.OperatorAddOpcodeIndex(new_builder, new_opcode_index)
    if input_num != 0:
        tflite.Operator.OperatorAddInputs(new_builder, new_input)
    if output_num != 0:
        tflite.Operator.OperatorAddOutputs(new_builder, new_output)
    tflite.Operator.OperatorAddBuiltinOptionsType(new_builder, builtin_option_type)
    if builtin_option_type != 0:
        tflite.Operator.OperatorAddBuiltinOptions(new_builder, new_builtin_option)
    if custom_option_num != 0:
        tflite.Operator.OperatorAddCustomOptions(new_builder, new_custom_option)
    tflite.Operator.OperatorAddCustomOptionsFormat(new_builder, custom_option_type)
    return tflite.Operator.OperatorEnd(new_builder)
def GenerateOperators(new_builder, selected_subgraph, operator_list, used_tensors_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Serialize the selected operators of ``selected_subgraph`` as a vector.

    Operators are emitted in their original order within the subgraph, not in
    the order they are listed in ``operator_list``. Indices in ``operator_list``
    that are repeated or out of range are ignored.

    Args:
        new_builder: flatbuffers ``Builder`` the new model is written into.
        selected_subgraph: ``tflite.SubGraph`` read from the source model.
        operator_list: iterable of operator indices to keep.
        used_tensors_dic, used_opcodes_dic, used_subgraphs_dic: source -> new
            index maps forwarded to ``GenerateOperator``.

    Returns:
        Offset of the operators vector, or 0 when the subgraph has no operators
        or none of them is selected.
    """
    operator_num = selected_subgraph.OperatorsLength()
    if operator_num == 0:
        return 0

    # Build the lookup set once: operator_list is typically a plain list, and
    # testing membership on it for every operator would make this quadratic.
    selected_indices = set(operator_list)

    new_operator_list = []
    for operator_idx in range(operator_num):
        if operator_idx not in selected_indices:
            continue
        selected_operator = selected_subgraph.Operators(operator_idx)
        new_operator_list.append(
            GenerateOperator(new_builder, selected_operator, used_tensors_dic,
                             used_opcodes_dic, used_subgraphs_dic))

    new_operator_num = len(new_operator_list)
    if new_operator_num == 0:
        return 0

    # Flatbuffers vectors are filled back-to-front.
    tflite.SubGraph.SubGraphStartOperatorsVector(new_builder, new_operator_num)
    for new_operator in reversed(new_operator_list):
        new_builder.PrependUOffsetTRelative(new_operator)
    return new_builder.EndVector(new_operator_num)
def GenerateSubgraph(new_builder, selected_subgraph, operator_list, new_input_tensor,
                     new_output_tensor, used_tensors_dic, used_buffers_dic,
                     used_opcodes_dic, used_subgraphs_dic):
    """Write one ``SubGraph`` table into ``new_builder`` and return its offset.

    The tensor vector, the input/output index vectors, the operator vector and
    the name string are all finished before ``SubGraphStart`` is called.

    Args:
        new_builder: flatbuffers ``Builder`` for the model being generated.
        selected_subgraph: ``tflite.SubGraph`` read from the source model.
        operator_list: indices of the operators of ``selected_subgraph`` to keep.
        new_input_tensor: source tensor indices exposed as subgraph inputs.
        new_output_tensor: source tensor indices exposed as subgraph outputs.
        used_tensors_dic: source tensor index -> new tensor index.
        used_buffers_dic: source buffer index -> new buffer index.
        used_opcodes_dic: source opcode index -> new opcode index.
        used_subgraphs_dic: source subgraph index -> new subgraph index.
    """

    def build_tensor_index_vector(start_vector, source_indices):
        # int32 vector of remapped tensor indices; flatbuffers vectors are
        # filled back-to-front, so walk the source indices in reverse.
        count = len(source_indices)
        start_vector(new_builder, count)
        for source_idx in reversed(source_indices):
            new_builder.PrependInt32(used_tensors_dic[source_idx])
        return new_builder.EndVector(count)

    tensors = GenerateTensors(new_builder, selected_subgraph, used_tensors_dic,
                              used_buffers_dic)

    # An empty input/output list leaves the corresponding field unset.
    inputs = None
    if len(new_input_tensor) != 0:
        inputs = build_tensor_index_vector(tflite.SubGraph.SubGraphStartInputsVector,
                                           new_input_tensor)

    outputs = None
    if len(new_output_tensor) != 0:
        outputs = build_tensor_index_vector(tflite.SubGraph.SubGraphStartOutputsVector,
                                            new_output_tensor)

    operators = GenerateOperators(new_builder, selected_subgraph, operator_list,
                                  used_tensors_dic, used_opcodes_dic, used_subgraphs_dic)

    # Only a non-empty name is copied.
    raw_name = selected_subgraph.Name()
    name = new_builder.CreateString(raw_name) if raw_name else None

    tflite.SubGraph.SubGraphStart(new_builder)
    tflite.SubGraph.SubGraphAddTensors(new_builder, tensors)
    if inputs is not None:
        tflite.SubGraph.SubGraphAddInputs(new_builder, inputs)
    if outputs is not None:
        tflite.SubGraph.SubGraphAddOutputs(new_builder, outputs)
    tflite.SubGraph.SubGraphAddOperators(new_builder, operators)
    if name is not None:
        tflite.SubGraph.SubGraphAddName(new_builder, name)
    return tflite.SubGraph.SubGraphEnd(new_builder)
def GenerateSubgraphs(args, new_builder, sample_model, operator_list, new_input_tensor,
                      new_output_tensor, used_tensors_dic, used_buffers_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Serialize every used subgraph and return the model's subgraphs vector.

    The subgraph selected by ``args.subgraph`` is rebuilt from ``operator_list``
    with the remapped tensor/buffer numbering and becomes the primary subgraph.
    Every other subgraph in ``used_subgraphs_dic`` (e.g. If/While bodies) is
    copied whole: all its operators, with identity tensor and buffer maps.

    Args:
        args: parsed command-line arguments; ``args.subgraph`` is read.
        new_builder: flatbuffers ``Builder`` for the model being generated.
        sample_model: ``tflite.Model`` read from the source file.
        operator_list: operator indices kept from the selected subgraph.
        new_input_tensor / new_output_tensor: source tensor indices exposed as
            the selected subgraph's inputs / outputs.
        used_tensors_dic, used_buffers_dic, used_opcodes_dic: source -> new
            index maps for the selected subgraph.
        used_subgraphs_dic: source subgraph index -> new subgraph index.

    Returns:
        Offset of the subgraphs vector.
    """
    new_subgraph_list = []

    # The selected subgraph will be the primary subgraph of the new model
    selected_subgraph = sample_model.Subgraphs(args.subgraph)

    # Emit subgraphs ordered by their new index so that each one lands at the
    # position If/While options were remapped to, independent of dict order.
    for old_idx, new_idx in sorted(used_subgraphs_dic.items(), key=lambda item: item[1]):
        print("Append subgraphs, old index : ", old_idx, ", new index : ", new_idx)
        if old_idx == args.subgraph:
            new_subgraph = GenerateSubgraph(new_builder, selected_subgraph, operator_list,
                                            new_input_tensor, new_output_tensor,
                                            used_tensors_dic, used_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
        else:
            subg = sample_model.Subgraphs(old_idx)
            subg_operator_idx_list = range(subg.OperatorsLength())
            # Read the index vectors element-wise: the generated *AsNumpy()
            # accessors return 0 instead of an array when a vector is absent,
            # which would break len() in GenerateSubgraph.
            subg_input_tensors = [subg.Inputs(i) for i in range(subg.InputsLength())]
            subg_output_tensors = [subg.Outputs(i) for i in range(subg.OutputsLength())]
            subg_tensors = range(subg.TensorsLength())
            # Child subgraphs keep their own tensor and buffer numbering.
            subg_tensors_dic = {tensor_idx: tensor_idx for tensor_idx in subg_tensors}
            subg_buffers_dic = {}
            for tensor_idx in subg_tensors:
                buffer_idx = subg.Tensors(tensor_idx).Buffer()
                subg_buffers_dic[buffer_idx] = buffer_idx
            new_subgraph = GenerateSubgraph(new_builder, subg, subg_operator_idx_list,
                                            subg_input_tensors, subg_output_tensors,
                                            subg_tensors_dic, subg_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
        new_subgraph_list.append(new_subgraph)

    # Flatbuffers vectors are filled back-to-front.
    new_subgraph_num = len(new_subgraph_list)
    tflite.Model.ModelStartSubgraphsVector(new_builder, new_subgraph_num)
    for new_subgraph in reversed(new_subgraph_list):
        new_builder.PrependUOffsetTRelative(new_subgraph)
    return new_builder.EndVector(new_subgraph_num)
def GenerateBuffers(new_builder, sample_model, used_buffers_dic):
    """Copy the used buffers of ``sample_model`` into the model being built.

    Buffers are emitted in their original index order, keeping only indices
    that are keys of ``used_buffers_dic``. A kept buffer with no data becomes a
    ``Buffer`` table without a data field.

    Args:
        new_builder: flatbuffers ``Builder`` for the model being generated.
        sample_model: ``tflite.Model`` read from the source file.
        used_buffers_dic: source buffer index -> new buffer index.

    Returns:
        Offset of the buffers vector, or 0 if no buffer is kept.
    """
    kept_buffers = [(buffer_idx, sample_model.Buffers(buffer_idx))
                    for buffer_idx in range(sample_model.BuffersLength())
                    if buffer_idx in used_buffers_dic]

    # Pass 1: serialize every non-empty data vector. They must be finished
    # before the Buffer tables referencing them are started.
    data_offsets = {}
    for buffer_idx, source_buffer in kept_buffers:
        data_length = source_buffer.DataLength()
        if data_length == 0:
            continue
        tflite.Buffer.BufferStartDataVector(new_builder, data_length)
        # Vectors are filled back-to-front: last byte first.
        for byte_idx in range(data_length - 1, -1, -1):
            new_builder.PrependUint8(source_buffer.Data(byte_idx))
        data_offsets[buffer_idx] = new_builder.EndVector(data_length)

    # Pass 2: one Buffer table per kept buffer, same order as pass 1.
    buffer_tables = []
    for buffer_idx, _ in kept_buffers:
        tflite.Buffer.BufferStart(new_builder)
        if buffer_idx in data_offsets:
            tflite.Buffer.BufferAddData(new_builder, data_offsets[buffer_idx])
        buffer_tables.append(tflite.Buffer.BufferEnd(new_builder))

    if not buffer_tables:
        return 0

    table_count = len(buffer_tables)
    tflite.Model.ModelStartBuffersVector(new_builder, table_count)
    for table in reversed(buffer_tables):
        new_builder.PrependUOffsetTRelative(table)
    return new_builder.EndVector(table_count)
# GenerateModel: assemble the complete Model table of the new model file and
# return its offset (the caller finishes the builder).
# All child objects (operator codes, subgraphs, description string, buffers) are
# created before ModelStart, as flatbuffers requires. The model version and the
# description string are copied unchanged from sample_model.
#   args: parsed command-line arguments; only forwarded to GenerateSubgraphs,
#         which reads args.subgraph.
#   new_input_tensors / new_output_tensors: source-model tensor indices that
#         become the primary subgraph's inputs / outputs.
#   used_*_dic: source index -> new index maps for tensors, buffers, opcodes
#         and subgraphs.
1171 def GenerateModel(args, new_builder, sample_model, operator_list, new_input_tensors,
1172 new_output_tensors, used_tensors_dic, used_buffers_dic,
1173 used_opcodes_dic, used_subgraphs_dic):
1175 version = sample_model.Version()
1177 # pointer of operator code 'table' vector
1178 operator_codes = GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
1182 subgraphs = GenerateSubgraphs(args, new_builder, sample_model, operator_list,
1183 new_input_tensors, new_output_tensors, used_tensors_dic,
1184 used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)
# Description string copied from the source model.
1187 description_string = new_builder.CreateString(sample_model.Description())
# Buffers kept according to used_buffers_dic.
1190 buffers = GenerateBuffers(new_builder, sample_model, used_buffers_dic)
# All children exist; assemble the Model table itself.
1193 tflite.Model.ModelStart(new_builder)
1194 tflite.Model.ModelAddVersion(new_builder, version)
1195 tflite.Model.ModelAddOperatorCodes(new_builder, operator_codes)
1196 tflite.Model.ModelAddSubgraphs(new_builder, subgraphs)
1197 tflite.Model.ModelAddDescription(new_builder, description_string)
1198 tflite.Model.ModelAddBuffers(new_builder, buffers)
1200 return tflite.Model.ModelEnd(new_builder)
# Entry-point body: copy the operators listed in args.opcode_list out of subgraph
# args.subgraph of args.input_model, together with the tensors, buffers, opcodes
# and child subgraphs they use, and write the result as a new model to
# args.output_model.
1204 input_model_file = args.input_model
1205 oplist_file = args.opcode_list
1206 output_model_file = args.output_model
1207 subgraph = args.subgraph
1209 # Parse operator list file
1210 operator_list = GetOperatorList(oplist_file)
1212 # Get sample model and subgraph
1213 # Operators are taken from subgraph args.subgraph (-g/--subgraph, default 0)
1214 sample_buf = input_model_file.read()
1215 sample_buf = bytearray(sample_buf)
1216 sample_model = tflite.Model.Model.GetRootAsModel(sample_buf, 0)
1217 sample_subgraph = sample_model.Subgraphs(subgraph)
# Subgraph numbering: the selected subgraph gets new index 0, followed by the
# subgraphs GetUsedSubgraphsList appends (presumably If/While children of the
# selected operators). NOTE(review): a repeated index in used_subgraphs_list
# would overwrite its earlier mapping in used_subgraphs_dic below.
1219 used_subgraphs_list = []
1220 used_subgraphs_list.append(args.subgraph)
1221 GetUsedSubgraphsList(sample_model, args.subgraph, operator_list, used_subgraphs_list)
1223 used_subgraphs_dic = {}
1224 for new_subgraph_idx in range(len(used_subgraphs_list)):
1225 sample_subgraph_idx = used_subgraphs_list[new_subgraph_idx]
1226 used_subgraphs_dic[sample_subgraph_idx] = new_subgraph_idx
1228 # Collect used tensor & used operator
# Every input (skipping -1 entries) and output tensor of a selected operator,
# and every opcode, is recorded once.
1232 for operator_idx in operator_list:
1233 operator = sample_subgraph.Operators(operator_idx)
1234 for input_idx in range(operator.InputsLength()):
1235 input_tensor_idx = operator.Inputs(input_idx)
1236 if not input_tensor_idx == -1 and not input_tensor_idx in used_tensors:
1237 # default: same as input sample
1238 used_tensors.append(input_tensor_idx)
1240 for output_idx in range(operator.OutputsLength()):
1241 output_tensor_idx = operator.Outputs(output_idx)
1242 if not output_tensor_idx in used_tensors:
1243 # default: same as input sample
1244 used_tensors.append(output_tensor_idx)
1246 opcode_idx = operator.OpcodeIndex()
1247 if not opcode_idx in used_opcodes:
1248 used_opcodes.append(opcode_idx)
1250 # Append opcodes of child subgraphs
# (child subgraphs are copied whole, so all of their opcodes are needed)
1251 for subgraph_idx in used_subgraphs_list:
1252 if subgraph_idx == subgraph:
1254 for operator_idx in range(sample_model.Subgraphs(subgraph_idx).OperatorsLength()):
1255 operator = sample_model.Subgraphs(subgraph_idx).Operators(operator_idx)
1256 opcode_idx = operator.OpcodeIndex()
1257 if not opcode_idx in used_opcodes:
1258 used_opcodes.append(opcode_idx)
1263 # Collect used buffer
1264 # buffer[0] should be blank. So it should start from 1
1267 for used_tensor in used_tensors:
1268 # key and value is same in prepare phase
1269 buf_idx = (sample_subgraph.Tensors(used_tensor)).Buffer()
1270 used_buffers.append(buf_idx)
1272 # Append buffers of tensors of child subgraphs
1273 for subgraph_idx in used_subgraphs_list:
1274 if subgraph_idx == subgraph:
1276 for tensor_idx in range(sample_model.Subgraphs(subgraph_idx).TensorsLength()):
1277 tensor = sample_model.Subgraphs(subgraph_idx).Tensors(tensor_idx)
1278 used_buffers.append(tensor.Buffer())
1282 # Assign new index for operator
1283 used_opcodes_dic = {}
1285 for new_operator_idx in range(len(used_opcodes)):
1286 sample_operator_idx = used_opcodes[new_operator_idx]
1287 used_opcodes_dic[sample_operator_idx] = new_operator_idx
1289 # Assign new index for tensor
1290 used_tensors_dic = {}
1292 for new_tensor_idx in range(len(used_tensors)):
1293 sample_tensor_idx = used_tensors[new_tensor_idx]
1294 used_tensors_dic[sample_tensor_idx] = new_tensor_idx
1296 # Assign new index for buffer
# NOTE(review): GenerateBuffers emits kept buffers in original index order, so
# new index == position in used_buffers only holds if used_buffers is sorted
# and free of duplicates at this point (a shared buffer index would map to its
# last position) — verify.
1297 used_buffers_dic = {}
1299 for new_buffer_idx in range(len(used_buffers)):
1300 sample_buffer_idx = used_buffers[new_buffer_idx]
1301 used_buffers_dic[sample_buffer_idx] = new_buffer_idx
1303 # Find input & output tensor in new model
# Start from all used tensors, then:
#  - a tensor consumed by a selected operator is not a model output, and is not
#    a model input either if its buffer holds data (a constant);
#  - a tensor produced by a selected operator is not a model input.
1304 new_input_tensors = used_tensors[:]
1305 new_output_tensors = used_tensors[:]
1307 for operator_idx in operator_list:
1308 operator = sample_subgraph.Operators(operator_idx)
1309 for input_idx in range(operator.InputsLength()):
1310 input_tensor_idx = operator.Inputs(input_idx)
1311 if input_tensor_idx == -1:
1313 if input_tensor_idx in new_output_tensors:
1314 new_output_tensors.remove(input_tensor_idx)
1315 if input_tensor_idx in new_input_tensors:
1316 matched_buffer_idx = sample_subgraph.Tensors(input_tensor_idx).Buffer()
1317 matched_buffer = sample_model.Buffers(matched_buffer_idx)
1318 if matched_buffer.DataLength() != 0:
1319 new_input_tensors.remove(input_tensor_idx)
1321 for output_idx in range(operator.OutputsLength()):
1322 output_tensor_idx = operator.Outputs(output_idx)
1323 if output_tensor_idx in new_input_tensors:
1324 new_input_tensors.remove(output_tensor_idx)
1325 if output_tensor_idx in new_output_tensors:
1326 matched_buffer_idx = sample_subgraph.Tensors(output_tensor_idx).Buffer()
1327 matched_buffer = sample_model.Buffers(matched_buffer_idx)
1328 if matched_buffer.DataLength() != 0:
# NOTE(review): likely bug — this is the outputs loop, so this should remove
# output_tensor_idx. input_tensor_idx is left over from the inputs loop above
# and has already been removed from new_output_tensors there, so this
# remove() would typically raise ValueError.
1329 new_output_tensors.remove(input_tensor_idx)
# The *_newidx lists below are only used for logging; GenerateModel receives the
# source indices and the subgraph code remaps them through used_tensors_dic.
1331 new_input_tensors_newidx = []
1332 new_output_tensors_newidx = []
1334 for input_tensor_idx in new_input_tensors:
1335 new_input_tensors_newidx.append(used_tensors_dic[input_tensor_idx])
1336 for output_tensor_idx in new_output_tensors:
1337 new_output_tensors_newidx.append(used_tensors_dic[output_tensor_idx])
1339 print("Input tensor(s): " + str(new_input_tensors_newidx))
1340 print("Output tensor(s): " + str(new_output_tensors_newidx))
1342 # Create new model file
# 1024 is only the builder's initial buffer size; it grows as needed.
1343 new_builder = flatbuffers.Builder(1024)
1345 new_model = GenerateModel(args, new_builder, sample_model, operator_list,
1346 new_input_tensors, new_output_tensors, used_tensors_dic,
1347 used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)
# Finish with the "TFL3" file identifier and write the serialized bytes out.
1349 new_builder.Finish(new_model, file_identifier=b'TFL3')
1350 new_buf = new_builder.Output()
1352 output_model_file.write(new_buf)
1355 if __name__ == '__main__':
1356 # Define argument and read
1357 arg_parser = argparse.ArgumentParser()
1358 arg_parser.add_argument(
1360 type=argparse.FileType('rb'),
1361 help="input tflite model file to read")
1362 arg_parser.add_argument(
1364 type=argparse.FileType('r'),
1365 help="text file including selected operator list")
1366 arg_parser.add_argument(
1367 "output_model", type=argparse.FileType('wb'), help="output tflite model file")
1368 arg_parser.add_argument(
1369 '-g', '--subgraph', type=int, default=0, help="subgraph to use (default: 0)")
1372 # Select multiple subgraph
1373 # Select subgraph by using opcode list file
1374 # Select opcode list by using argument
1376 args = arg_parser.parse_args()
1378 # Call main function