3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
21 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tflite'))
24 os.path.dirname(os.path.abspath(__file__)), '../../externals/flatbuffers/python'))
28 import tflite.SubGraph
29 import tflite.BuiltinOptions
33 # Assume we use only main model in model file
34 # Get selected operators from file, and return operator index list
def GetOperatorList(oplist_file):
    """Parse the selected operator indices from an operator-list file.

    The file content is split on whitespace (spaces and/or line breaks). Each
    token must be either a decimal operator index (``"3"``) or an inclusive
    index range (``"2-5"``). Indices are returned in the order they are read;
    duplicates are kept. A range whose start is greater than its end
    contributes nothing.

    Args:
        oplist_file: Object providing ``readlines()`` that yields text lines
            (e.g. an open text file).

    Returns:
        List of operator indices as ``int``.

    Raises:
        SystemExit: With status 1, after printing a message, when a token is
            neither an index nor a range, or when no operator was selected.
    """
    opcode_list = []

    for line in oplist_file.readlines():
        for word in line.split():
            if word.isdigit():
                opcode_list.append(int(word))
                continue

            opcode_range = word.split('-')
            if (len(opcode_range) == 2 and opcode_range[0].isdigit()
                    and opcode_range[1].isdigit()):
                start = int(opcode_range[0])
                end = int(opcode_range[1])
                # The range is inclusive on both ends.
                opcode_list.extend(range(start, end + 1))
                continue

            print("Error: Cannot get operator list")
            print(
                "Please pass operators as operator index or range list split by space and/or line"
            )
            # sys.exit() rather than the interactive-only exit() helper, which
            # is not guaranteed to exist (e.g. when the site module is disabled).
            sys.exit(1)

    if not opcode_list:
        print("No selected operator")
        sys.exit(1)

    return opcode_list
def GetUsedSubgraphsList(sample_model, subg_num, operator_list, used_subgraphs_list):
    """Collect, recursively, the subgraphs referenced by If/While operators.

    Looks at the operators of subgraph ``subg_num`` selected by
    ``operator_list``. For every operator whose builtin options are
    ``IfOptions`` or ``WhileOptions`` the referenced subgraph indices are
    gathered; each index not yet present in ``used_subgraphs_list`` is appended
    to it and that subgraph is then scanned the same way.

    Args:
        new_builder: not used here.
        sample_model: Model accessor providing ``Subgraphs(idx)``.
        subg_num: Index of the subgraph to scan.
        operator_list: Iterable of operator indices (within ``subg_num``) to scan.
        used_subgraphs_list: List of subgraph indices; extended in place.

    Returns:
        None. The result is accumulated in ``used_subgraphs_list``.
    """
    import tflite.IfOptions
    import tflite.WhileOptions

    builtin_options = tflite.BuiltinOptions.BuiltinOptions()
    subg_list = []

    selected_subgraph = sample_model.Subgraphs(subg_num)

    for operator_idx in operator_list:
        selected_operator = selected_subgraph.Operators(operator_idx)
        options_type = selected_operator.BuiltinOptionsType()

        if options_type == builtin_options.IfOptions:
            selected_builtin_option = selected_operator.BuiltinOptions()
            if_option = tflite.IfOptions.IfOptions()
            if_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)

            subg_list.append(if_option.ElseSubgraphIndex())
            subg_list.append(if_option.ThenSubgraphIndex())
        elif options_type == builtin_options.WhileOptions:
            selected_builtin_option = selected_operator.BuiltinOptions()
            while_option = tflite.WhileOptions.WhileOptions()
            while_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)

            subg_list.append(while_option.BodySubgraphIndex())
            subg_list.append(while_option.CondSubgraphIndex())

    for idx in subg_list:
        if idx not in used_subgraphs_list:
            # Record the subgraph before recursing so that a subgraph that
            # refers back to an already-seen one does not recurse forever.
            used_subgraphs_list.append(idx)
            # Scan every operator of the nested subgraph. The previous bound of
            # OperatorsLength() - 1 skipped its last operator, so an If/While
            # in the final position was never followed.
            GetUsedSubgraphsList(sample_model, idx,
                                 range(sample_model.Subgraphs(idx).OperatorsLength()),
                                 used_subgraphs_list)
def GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
                          used_subgraphs_dic):
    """Write the operator_codes vector for the operator codes that are in use.

    Args:
        new_builder: FlatBuffers builder the new model is written to.
        sample_model: Source model accessor providing ``OperatorCodesLength()``
            and ``OperatorCodes(idx)``.
        used_opcodes_dic: Container of operator-code indices to keep; only
            membership (``idx in used_opcodes_dic``) is used.
        used_subgraphs_dic: Not used by this function.
            NOTE(review): this parameter name is not visible in the extracted
            text of the signature - confirm it against the callers.

    Returns:
        Offset of the finished vector in ``new_builder``, or 0 when the source
        model has no operator codes.
    """
    operator_code_num = sample_model.OperatorCodesLength()
    if operator_code_num == 0:
        return 0

    # Kept operator codes, in source order.
    used_operator_codes = [
        sample_model.OperatorCodes(operator_code_idx)
        for operator_code_idx in range(operator_code_num)
        if operator_code_idx in used_opcodes_dic
    ]

    # Create operator_code strings.
    # Strings have to be written before a table is started, because the builder
    # does not allow creating one object while another is under construction.
    new_operator_code_string_list = {}
    for operator_code in used_operator_codes:
        operator_code_string = operator_code.CustomCode()
        if operator_code_string and (
                operator_code_string not in new_operator_code_string_list):
            new_operator_code_string_list[
                operator_code_string] = new_builder.CreateString(operator_code_string)

    # Create tables of operator_code
    new_operator_code_list = []
    for operator_code in used_operator_codes:
        tflite.OperatorCode.OperatorCodeStart(new_builder)
        tflite.OperatorCode.OperatorCodeAddBuiltinCode(new_builder,
                                                       operator_code.BuiltinCode())
        new_operator_code_string = operator_code.CustomCode()
        # Only operator codes with a non-empty custom code have an entry here.
        if new_operator_code_string in new_operator_code_string_list:
            tflite.OperatorCode.OperatorCodeAddCustomCode(
                new_builder, new_operator_code_string_list[new_operator_code_string])
        new_operator_code_list.append(tflite.OperatorCode.OperatorCodeEnd(new_builder))

    # Create operator_code vector (elements are prepended, hence the reverse order)
    new_operator_code_num = len(new_operator_code_list)
    tflite.Model.ModelStartOperatorCodesVector(new_builder, new_operator_code_num)
    for new_operator_code in reversed(new_operator_code_list):
        new_builder.PrependUOffsetTRelative(new_operator_code)

    return new_builder.EndVector(new_operator_code_num)
def GenerateQuantization(new_builder, selected_quantization):
    """Copy a QuantizationParameters table into ``new_builder``.

    The min, max, scale and zero_point vectors are each copied only when the
    source vector is non-empty; a field whose source vector is empty is left
    unset in the new table.

    Args:
        new_builder: FlatBuffers builder the new model is written to.
        selected_quantization: Source quantization accessor providing
            ``MinLength()/Min(i)``, ``MaxLength()/Max(i)``,
            ``ScaleLength()/Scale(i)`` and ``ZeroPointLength()/ZeroPoint(i)``.

    Returns:
        Offset of the new QuantizationParameters table in ``new_builder``.
    """
    quant_params = tflite.QuantizationParameters

    # All vectors are written before the table is started. Elements are
    # prepended, so each source vector is walked from its last index to its first.

    # Create min vector
    new_min = None
    min_num = selected_quantization.MinLength()
    if min_num != 0:
        quant_params.QuantizationParametersStartMinVector(new_builder, min_num)
        for min_idx in reversed(range(min_num)):
            new_builder.PrependFloat32(selected_quantization.Min(min_idx))
        new_min = new_builder.EndVector(min_num)

    # Create max vector
    new_max = None
    max_num = selected_quantization.MaxLength()
    if max_num != 0:
        quant_params.QuantizationParametersStartMaxVector(new_builder, max_num)
        for max_idx in reversed(range(max_num)):
            new_builder.PrependFloat32(selected_quantization.Max(max_idx))
        new_max = new_builder.EndVector(max_num)

    # Create scale vector
    new_scale = None
    scale_num = selected_quantization.ScaleLength()
    if scale_num != 0:
        quant_params.QuantizationParametersStartScaleVector(new_builder, scale_num)
        for scale_idx in reversed(range(scale_num)):
            new_builder.PrependFloat32(selected_quantization.Scale(scale_idx))
        new_scale = new_builder.EndVector(scale_num)

    # Create zero_point vector
    new_zeropoint = None
    zeropoint_num = selected_quantization.ZeroPointLength()
    if zeropoint_num != 0:
        quant_params.QuantizationParametersStartZeroPointVector(new_builder, zeropoint_num)
        for zeropoint_idx in reversed(range(zeropoint_num)):
            new_builder.PrependInt64(selected_quantization.ZeroPoint(zeropoint_idx))
        new_zeropoint = new_builder.EndVector(zeropoint_num)

    # Create quantization
    quant_params.QuantizationParametersStart(new_builder)
    if new_min is not None:
        quant_params.QuantizationParametersAddMin(new_builder, new_min)
    if new_max is not None:
        quant_params.QuantizationParametersAddMax(new_builder, new_max)
    if new_scale is not None:
        quant_params.QuantizationParametersAddScale(new_builder, new_scale)
    if new_zeropoint is not None:
        quant_params.QuantizationParametersAddZeroPoint(new_builder, new_zeropoint)

    return quant_params.QuantizationParametersEnd(new_builder)
def GenerateTensor(new_builder, selected_tensor, used_buffers_dic):
    """Copy one tensor table into ``new_builder``.

    Shape and type are copied as they are, the buffer index is translated
    through ``used_buffers_dic``, and the name and quantization are copied when
    the source tensor has them.

    Args:
        new_builder: FlatBuffers builder the new model is written to.
        selected_tensor: Source tensor accessor providing ``ShapeLength()``,
            ``Shape(i)``, ``Type()``, ``Buffer()``, ``Name()`` and
            ``Quantization()``.
        used_buffers_dic: Mapping from a buffer index of the source model to the
            buffer index to store in the new tensor.

    Returns:
        Offset of the new Tensor table in ``new_builder``.

    Raises:
        KeyError: If the tensor's buffer index is missing from
            ``used_buffers_dic`` (when it is a dict).
    """
    # Create shape vector for tensor (elements are prepended, so walk backwards)
    shape_num = selected_tensor.ShapeLength()
    tflite.Tensor.TensorStartShapeVector(new_builder, shape_num)
    for shape_idx in reversed(range(shape_num)):
        new_builder.PrependInt32(selected_tensor.Shape(shape_idx))
    new_shape = new_builder.EndVector(shape_num)

    # Create tensor_type
    tensor_type = selected_tensor.Type()

    # Translate the buffer index to its index in the new model
    buffer_idx = selected_tensor.Buffer()
    new_buffer_idx = used_buffers_dic[buffer_idx]

    # Create name string.
    # A truthiness test skips both a missing name (None) and an empty one; the
    # former comparison against "" never matched a bytes/None value, so a
    # missing name was handed to CreateString().
    new_name = None
    name_string = selected_tensor.Name()
    if name_string:
        new_name = new_builder.CreateString(name_string)

    # Create quantization
    new_quantization = None
    quantization = selected_tensor.Quantization()
    if quantization is not None:
        new_quantization = GenerateQuantization(new_builder, quantization)

    # Create tensor (all child objects above are finished before the table starts)
    tflite.Tensor.TensorStart(new_builder)
    tflite.Tensor.TensorAddShape(new_builder, new_shape)
    tflite.Tensor.TensorAddType(new_builder, tensor_type)
    tflite.Tensor.TensorAddBuffer(new_builder, new_buffer_idx)
    if new_name is not None:
        tflite.Tensor.TensorAddName(new_builder, new_name)
    if new_quantization is not None:
        tflite.Tensor.TensorAddQuantization(new_builder, new_quantization)

    return tflite.Tensor.TensorEnd(new_builder)
def GenerateTensors(new_builder, selected_subgraph, used_tensors_dic, used_buffers_dic):
    """Write the tensors vector of a subgraph, keeping only the used tensors.

    Args:
        new_builder: FlatBuffers builder the new model is written to.
        selected_subgraph: Source subgraph accessor providing
            ``TensorsLength()`` and ``Tensors(idx)``.
        used_tensors_dic: Container of tensor indices to keep; only membership
            (``idx in used_tensors_dic``) is used.
        used_buffers_dic: Buffer index mapping, passed on to ``GenerateTensor``.

    Returns:
        Offset of the finished vector in ``new_builder``, or 0 when the
        subgraph has no tensors or none of them is kept.
    """
    tensor_num = selected_subgraph.TensorsLength()
    if tensor_num == 0:
        return 0

    # Every tensor table is finished before the vector that refers to it is started.
    new_tensor_list = [
        GenerateTensor(new_builder, selected_subgraph.Tensors(tensor_idx), used_buffers_dic)
        for tensor_idx in range(tensor_num) if tensor_idx in used_tensors_dic
    ]

    new_tensor_num = len(new_tensor_list)
    if new_tensor_num == 0:
        return 0

    # Elements are prepended, so the offsets are pushed in reverse order.
    tflite.SubGraph.SubGraphStartTensorsVector(new_builder, new_tensor_num)
    for new_tensor in reversed(new_tensor_list):
        new_builder.PrependUOffsetTRelative(new_tensor)

    return new_builder.EndVector(new_tensor_num)
264 def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type,
268 import tflite.Conv2DOptions
269 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Conv2DOptions:
271 conv2d_options = tflite.Conv2DOptions.Conv2DOptions()
272 conv2d_options.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
274 tflite.Conv2DOptions.Conv2DOptionsStart(new_builder)
275 tflite.Conv2DOptions.Conv2DOptionsAddPadding(new_builder,
276 conv2d_options.Padding())
277 tflite.Conv2DOptions.Conv2DOptionsAddStrideW(new_builder,
278 conv2d_options.StrideW())
279 tflite.Conv2DOptions.Conv2DOptionsAddStrideH(new_builder,
280 conv2d_options.StrideH())
281 tflite.Conv2DOptions.Conv2DOptionsAddFusedActivationFunction(
282 new_builder, conv2d_options.FusedActivationFunction())
283 return tflite.Conv2DOptions.Conv2DOptionsEnd(new_builder)
285 # DepthwiseConv2D option
286 import tflite.DepthwiseConv2DOptions
287 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
288 ).DepthwiseConv2DOptions:
290 depthconv2d_option = tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions()
291 depthconv2d_option.Init(selected_builtin_option.Bytes,
292 selected_builtin_option.Pos)
294 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsStart(new_builder)
295 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddPadding(
296 new_builder, depthconv2d_option.Padding())
297 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideW(
298 new_builder, depthconv2d_option.StrideW())
299 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideH(
300 new_builder, depthconv2d_option.StrideH())
301 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDepthMultiplier(
302 new_builder, depthconv2d_option.DepthMultiplier())
303 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddFusedActivationFunction(
304 new_builder, depthconv2d_option.FusedActivationFunction())
305 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationWFactor(
306 new_builder, depthconv2d_option.DilationWFactor())
307 tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationHFactor(
308 new_builder, depthconv2d_option.DilationHFactor())
309 return tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsEnd(new_builder)
311 # ConcatEmbeddingsOptions: not supported
312 # LSHProjectionOptions: not supported
315 import tflite.Pool2DOptions
316 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Pool2DOptions:
318 pool2d_option = tflite.Pool2DOptions.Pool2DOptions()
319 pool2d_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
321 tflite.Pool2DOptions.Pool2DOptionsStart(new_builder)
322 tflite.Pool2DOptions.Pool2DOptionsAddPadding(new_builder, pool2d_option.Padding())
323 tflite.Pool2DOptions.Pool2DOptionsAddStrideW(new_builder, pool2d_option.StrideW())
324 tflite.Pool2DOptions.Pool2DOptionsAddStrideH(new_builder, pool2d_option.StrideH())
325 tflite.Pool2DOptions.Pool2DOptionsAddFilterWidth(new_builder,
326 pool2d_option.FilterWidth())
327 tflite.Pool2DOptions.Pool2DOptionsAddFilterHeight(new_builder,
328 pool2d_option.FilterHeight())
329 tflite.Pool2DOptions.Pool2DOptionsAddFusedActivationFunction(
330 new_builder, pool2d_option.FusedActivationFunction())
331 return tflite.Pool2DOptions.Pool2DOptionsEnd(new_builder)
333 # SVDFOptions: not supported
336 import tflite.RNNOptions
337 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().RNNOptions:
339 rnn_option = tflite.RNNOptions.RNNOptions()
340 rnn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
342 tflite.RNNOptions.RNNOptionsStart(new_builder)
343 tflite.RNNOptions.RNNOptionsAddFusedActivationFunction(
344 new_builder, rnn_option.FusedActivationFunction())
345 return tflite.RNNOptions.RNNOptionsEnd(new_builder)
347 # FullyConnectedOptions
348 import tflite.FullyConnectedOptions
349 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
350 ).FullyConnectedOptions:
352 fc_option = tflite.FullyConnectedOptions.FullyConnectedOptions()
353 fc_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
355 tflite.FullyConnectedOptions.FullyConnectedOptionsStart(new_builder)
356 tflite.FullyConnectedOptions.FullyConnectedOptionsAddFusedActivationFunction(
357 new_builder, fc_option.FusedActivationFunction())
358 return tflite.FullyConnectedOptions.FullyConnectedOptionsEnd(new_builder)
361 import tflite.SoftmaxOptions
362 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SoftmaxOptions:
364 softmax_option = tflite.SoftmaxOptions.SoftmaxOptions()
365 softmax_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
367 tflite.SoftmaxOptions.SoftmaxOptionsStart(new_builder)
368 tflite.SoftmaxOptions.SoftmaxOptionsAddBeta(new_builder, softmax_option.Beta())
369 return tflite.SoftmaxOptions.SoftmaxOptionsEnd(new_builder)
371 # ConcatenationOptions
372 import tflite.ConcatenationOptions
373 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ConcatenationOptions:
375 concat_option = tflite.ConcatenationOptions.ConcatenationOptions()
376 concat_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
378 tflite.ConcatenationOptions.ConcatenationOptionsStart(new_builder)
379 tflite.ConcatenationOptions.ConcatenationOptionsAddAxis(
380 new_builder, concat_option.Axis())
381 tflite.ConcatenationOptions.ConcatenationOptionsAddFusedActivationFunction(
382 new_builder, concat_option.FusedActivationFunction())
383 return tflite.ConcatenationOptions.ConcatenationOptionsEnd(new_builder)
386 import tflite.AddOptions
387 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions:
389 add_option = tflite.AddOptions.AddOptions()
390 add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
392 tflite.AddOptions.AddOptionsStart(new_builder)
393 tflite.AddOptions.AddOptionsAddFusedActivationFunction(
394 new_builder, add_option.FusedActivationFunction())
395 return tflite.AddOptions.AddOptionsEnd(new_builder)
398 import tflite.L2NormOptions
399 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().L2NormOptions:
401 l2norm_option = tflite.L2NormOptions.L2NormOptions()
402 l2norm_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
404 tflite.L2NormOptions.L2NormOptionsStart(new_builder)
405 tflite.L2NormOptions.L2NormOptionsAddFusedActivationFunction(
406 new_builder, l2norm_option.FusedActivationFunction())
407 return tflite.L2NormOptions.L2NormOptionsEnd(new_builder)
409 # LocalResponseNormalizationOptions
410 import tflite.LocalResponseNormalizationOptions
411 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
412 ).LocalResponseNormalizationOptions:
414 lrn_option = tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptions(
416 lrn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
418 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsStart(
420 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddRadius(
421 new_builder, lrn_option.Radius())
422 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBias(
423 new_builder, lrn_option.Bias())
424 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddAlpha(
425 new_builder, lrn_option.Alpha())
426 tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBeta(
427 new_builder, lrn_option.Beta())
428 return tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsEnd(
431 # LSTMOptions: not supported
433 # ResizeBilinearOptions
434 import tflite.ResizeBilinearOptions
435 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
436 ).ResizeBilinearOptions:
438 resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions()
439 resize_bilinear_option.Init(selected_builtin_option.Bytes,
440 selected_builtin_option.Pos)
442 tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder)
443 tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners(
444 new_builder, resize_bilinear_option.AlignCorners())
445 return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder)
447 # CallOptions: not supported
450 import tflite.ReshapeOptions
451 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReshapeOptions:
453 reshape_option = tflite.ReshapeOptions.ReshapeOptions()
454 reshape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
456 shape_num = reshape_option.NewShapeLength()
458 tflite.ReshapeOptions.ReshapeOptionsStartNewShapeVector(
459 new_builder, shape_num)
460 for new_shape_idx in reversed(range(shape_num)):
461 new_shape_val = reshape_option.NewShape(new_shape_idx)
462 new_builder.PrependInt32(new_shape_val)
463 new_shape = new_builder.EndVector(shape_num)
465 tflite.ReshapeOptions.ReshapeOptionsStart(new_builder)
467 tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape)
468 return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder)
470 # SkipGramOptions: not supported
472 # SpaceToDepthOptions
473 import tflite.SpaceToDepthOptions
474 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SpaceToDepthOptions:
476 space_to_depth_option = tflite.SpaceToDepthOptions.SpaceToDepthOptions()
477 space_to_depth_option.Init(selected_builtin_option.Bytes,
478 selected_builtin_option.Pos)
480 tflite.SpaceToDepthOptions.SpaceToDepthOptionsStart(new_builder)
481 tflite.SpaceToDepthOptions.SpaceToDepthOptionsAddBlockSize(
482 new_builder, space_to_depth_option.BlockSize())
483 return tflite.SpaceToDepthOptions.SpaceToDepthOptionsEnd(new_builder)
485 # EmbeddingLookupSparseOptions: not supported
488 import tflite.MulOptions
489 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().MulOptions:
491 mul_option = tflite.MulOptions.MulOptions()
492 mul_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
494 tflite.MulOptions.MulOptionsStart(new_builder)
495 tflite.MulOptions.MulOptionsAddFusedActivationFunction(
496 new_builder, mul_option.FusedActivationFunction())
497 return tflite.MulOptions.MulOptionsEnd(new_builder)
500 import tflite.PadOptions
501 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PadOptions:
503 pad_option = tflite.PadOptions.PadOptions()
504 pad_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
506 tflite.PadOptions.PadOptionsStart(new_builder)
507 return tflite.PadOptions.PadOptionsEnd(new_builder)
510 import tflite.GatherOptions
511 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().GatherOptions:
513 gather_option = tflite.GatherOptions.GatherOptions()
514 gather_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
516 tflite.GatherOptions.GatherOptionsStart(new_builder)
517 tflite.GatherOptions.GatherOptionsAddAxis(new_builder, gather_option.Axis())
518 return tflite.GatherOptions.GatherOptionsEnd(new_builder)
520 # BatchToSpaceNDOptions
521 import tflite.BatchToSpaceNDOptions
522 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
523 ).BatchToSpaceNDOptions:
525 btsnd_option = tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptions()
526 btsnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
528 tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsStart(new_builder)
529 return tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsEnd(new_builder)
531 # SpaceToBatchNDOptions
532 import tflite.SpaceToBatchNDOptions
533 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
534 ).SpaceToBatchNDOptions:
536 stbnd_option = tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptions()
537 stbnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
539 tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsStart(new_builder)
540 return tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsEnd(new_builder)
543 import tflite.TransposeOptions
544 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeOptions:
546 transpose_option = tflite.TransposeOptions.TransposeOptions()
547 transpose_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
549 tflite.TransposeOptions.TransposeOptionsStart(new_builder)
550 return tflite.TransposeOptions.TransposeOptionsEnd(new_builder)
553 import tflite.ReducerOptions
554 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReducerOptions:
556 reducer_option = tflite.ReducerOptions.ReducerOptions()
557 reducer_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
559 tflite.ReducerOptions.ReducerOptionsStart(new_builder)
560 tflite.ReducerOptions.ReducerOptionsAddKeepDims(new_builder,
561 reducer_option.KeepDims())
562 return tflite.ReducerOptions.ReducerOptionsEnd(new_builder)
565 import tflite.SubOptions
566 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions:
568 sub_option = tflite.SubOptions.SubOptions()
569 sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
571 tflite.SubOptions.SubOptionsStart(new_builder)
572 tflite.SubOptions.SubOptionsAddFusedActivationFunction(
573 new_builder, sub_option.FusedActivationFunction())
574 return tflite.SubOptions.SubOptionsEnd(new_builder)
577 import tflite.DivOptions
578 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DivOptions:
580 div_option = tflite.DivOptions.DivOptions()
581 div_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
583 tflite.DivOptions.DivOptionsStart(new_builder)
584 tflite.DivOptions.DivOptionsAddFusedActivationFunction(
585 new_builder, div_option.FusedActivationFunction())
586 return tflite.DivOptions.DivOptionsEnd(new_builder)
589 import tflite.SqueezeOptions
590 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SqueezeOptions:
592 squeeze_option = tflite.SqueezeOptions.SqueezeOptions()
593 squeeze_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
595 squeeze_dims_num = squeeze_option.SqueezeDimsLength()
596 if squeeze_dims_num != 0:
597 tflite.SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector(
598 new_builder, squeeze_dims_num)
599 for squeeze_dims_idx in reversed(range(squeeze_dims_num)):
600 squeeze_dims_val = squeeze_option.SqueezeDims(squeeze_dims_idx)
601 new_builder.PrependInt32(squeeze_dims_val)
602 new_squeeze_dims = new_builder.EndVector(squeeze_dims_num)
604 tflite.SqueezeOptions.SqueezeOptionsStart(new_builder)
605 if squeeze_dims_num != 0:
606 tflite.SqueezeOptions.SqueezeOptionsAddSqueezeDims(new_builder,
608 return tflite.SqueezeOptions.SqueezeOptionsEnd(new_builder)
610 # SequenceRNNOptions: not supported
612 # StridedSliceOptions
613 import tflite.StridedSliceOptions
614 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().StridedSliceOptions:
616 stride_slice_option = tflite.StridedSliceOptions.StridedSliceOptions()
617 stride_slice_option.Init(selected_builtin_option.Bytes,
618 selected_builtin_option.Pos)
620 tflite.StridedSliceOptions.StridedSliceOptionsStart(new_builder)
621 tflite.StridedSliceOptions.StridedSliceOptionsAddBeginMask(
622 new_builder, stride_slice_option.BeginMask())
623 tflite.StridedSliceOptions.StridedSliceOptionsAddEndMask(
624 new_builder, stride_slice_option.EndMask())
625 tflite.StridedSliceOptions.StridedSliceOptionsAddEllipsisMask(
626 new_builder, stride_slice_option.EllipsisMask())
627 tflite.StridedSliceOptions.StridedSliceOptionsAddNewAxisMask(
628 new_builder, stride_slice_option.NewAxisMask())
629 tflite.StridedSliceOptions.StridedSliceOptionsAddShrinkAxisMask(
630 new_builder, stride_slice_option.ShrinkAxisMask())
632 return tflite.StridedSliceOptions.StridedSliceOptionsEnd(new_builder)
635 import tflite.ExpOptions
636 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpOptions:
638 exp_option = tflite.ExpOptions.ExpOptions()
639 exp_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
641 tflite.ExpOptions.ExpOptionsStart(new_builder)
642 return tflite.ExpOptions.ExpOptionsEnd(new_builder)
645 import tflite.TopKV2Options
646 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TopKV2Options:
648 topkv2_option = tflite.TopKV2Options.TopKV2Options()
649 topkv2_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
651 tflite.TopKV2Options.TopKV2OptionsStart(new_builder)
652 return tflite.TopKV2Options.TopKV2OptionsEnd(new_builder)
655 import tflite.SplitOptions
656 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SplitOptions:
658 split_option = tflite.SplitOptions.SplitOptions()
659 split_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
661 tflite.SplitOptions.SplitOptionsStart(new_builder)
662 tflite.SplitOptions.SplitOptionsAddNumSplits(new_builder,
663 split_option.NumSplits())
664 return tflite.SplitOptions.SplitOptionsEnd(new_builder)
666 # LogSoftmaxOptions: not supported
668 # CastOptions
669 import tflite.CastOptions
670 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions:
672 cast_option = tflite.CastOptions.CastOptions()
673 cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
675 tflite.CastOptions.CastOptionsStart(new_builder)
676 return tflite.CastOptions.CastOptionsEnd(new_builder)
679 import tflite.DequantizeOptions
680 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DequantizeOptions:
682 dequantize_option = tflite.DequantizeOptions.DequantizeOptions()
683 dequantize_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
685 tflite.EqualOptions.DequantizeOptionsStart(new_builder)
686 return tflite.DequantizeOptions.DequantizeOptionsEnd(new_builder)
688 # MaximumMinimumOptions: not supported
691 import tflite.ArgMaxOptions
692 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ArgMaxOptions:
694 arg_max_option = tflite.ArgMaxOptions.ArgMaxOptions()
695 arg_max_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
697 tflite.ArgMaxOptions.ArgMaxOptionsStart(new_builder)
698 tflite.ArgMaxOptions.ArgMaxOptionsAddOutputType(new_builder,
699 arg_max_option.OutputType())
700 return tflite.ArgMaxOptions.ArgMaxOptionsEnd(new_builder)
702 # LessOptions: not supported
705 import tflite.NegOptions
706 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NegOptions:
708 neg_option = tflite.NegOptions.NegOptions()
709 neg_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
711 tflite.NegOptions.NegOptionsStart(new_builder)
712 return tflite.NegOptions.NegOptionsEnd(new_builder)
715 import tflite.EqualOptions
716 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().EqualOptions:
718 equal_option = tflite.EqualOptions.EqualOptions()
719 equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
721 tflite.EqualOptions.EqualOptionsStart(new_builder)
722 return tflite.EqualOptions.EqualOptionsEnd(new_builder)
724 # PadV2Options: not supported
725 # GreaterOptions: not supported
726 # GreaterEqualOptions: not supported
727 # LessEqualOptions: not supported
728 # SelectOptions: not supported
729 # SliceOptions: not supported
731 # TransposeConvOptions
732 import tflite.TransposeConvOptions
733 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeConvOptions:
735 transposeconv_option = tflite.TransposeConvOptions.TransposeConvOptions()
736 transposeconv_option.Init(selected_builtin_option.Bytes,
737 selected_builtin_option.Pos)
739 tflite.TransposeConvOptions.TransposeConvOptionsStart(new_builder)
740 tflite.TransposeConvOptions.TransposeConvOptionsAddPadding(
741 new_builder, transposeconv_option.Padding())
742 tflite.TransposeConvOptions.TransposeConvOptionsAddStrideW(
743 new_builder, transposeconv_option.StrideW())
744 tflite.TransposeConvOptions.TransposeConvOptionsAddStrideH(
745 new_builder, transposeconv_option.StrideH())
746 return tflite.TransposeConvOptions.TransposeConvOptionsEnd(new_builder)
748 # SparseToDenseOptions: not supported
749 # TileOptions: not supported
752 import tflite.ExpandDimsOptions
753 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpandDimsOptions:
755 expanddims_option = tflite.ExpandDimsOptions.ExpandDimsOptions()
756 expanddims_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
758 tflite.ExpandDimsOptions.ExpandDimsOptionsStart(new_builder)
759 return tflite.ExpandDimsOptions.ExpandDimsOptionsEnd(new_builder)
762 import tflite.NotEqualOptions
763 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NotEqualOptions:
765 notequal_option = tflite.NotEqualOptions.NotEqualOptions()
766 notequal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
768 tflite.NotEqualOptions.NotEqualOptionsStart(new_builder)
769 return tflite.NotEqualOptions.NotEqualOptionsEnd(new_builder)
772 import tflite.ShapeOptions
773 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ShapeOptions:
775 shape_option = tflite.ShapeOptions.ShapeOptions()
776 shape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
778 tflite.ShapeOptions.ShapeOptionsStart(new_builder)
779 tflite.ShapeOptions.ShapeOptionsAddOutType(new_builder, shape_option.OutType())
780 return tflite.ShapeOptions.ShapeOptionsEnd(new_builder)
782 # PowOptions: not supported
783 # ArgMinOptions: not supported
784 # FakeQuantOptions: not supported
787 import tflite.PackOptions
788 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PackOptions:
790 pack_option = tflite.PackOptions.PackOptions()
791 pack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
793 tflite.PackOptions.PackOptionsStart(new_builder)
794 tflite.PackOptions.PackOptionsAddValuesCount(new_builder,
795 pack_option.ValuesCount())
796 tflite.PackOptions.PackOptionsAddAxis(new_builder, pack_option.Axis())
797 return tflite.PackOptions.PackOptionsEnd(new_builder)
800 import tflite.LogicalOrOptions
801 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalOrOptions:
803 logical_or_option = tflite.LogicalAndOptions.LogicalOrOptions()
804 logical_or_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
806 tflite.LogicalOrOptions.LogicalOrOptionsStart(new_builder)
807 return tflite.LogicalOrOptions.LogicalOrOptionsEnd(new_builder)
809 # OneHotOptions
810 import tflite.OneHotOptions
811 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().OneHotOptions:
813 one_hot_option = tflite.OneHotOptions.OneHotOptions()
814 one_hot_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
816 tflite.OneHotOptions.OneHotOptionsStart(new_builder)
817 tflite.OneHotOptions.OneHotOptionsAddAxis(new_builder, one_hot_option.Axis())
818 return tflite.OneHotOptions.OneHotOptionsEnd(new_builder)
821 import tflite.LogicalNotOptions
822 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalNotOptions:
824 equal_option = tflite.LogicalNotOptions.LogicalNotOptions()
825 equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
827 tflite.LogicalNotOptions.LogicalNotOptionsStart(new_builder)
828 return tflite.LogicalNotOptions.LogicalNotOptionsEnd(new_builder)
831 import tflite.UnpackOptions
832 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().UnpackOptions:
834 unpack_option = tflite.UnpackOptions.UnpackOptions()
835 unpack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
837 tflite.UnpackOptions.UnpackOptionsStart(new_builder)
838 tflite.UnpackOptions.UnpackOptionsAddNum(new_builder, unpack_option.Num())
839 tflite.UnpackOptions.UnpackOptionsAddAxis(new_builder, unpack_option.Axis())
840 return tflite.UnpackOptions.UnpackOptionsEnd(new_builder)
842 # FloorDivOptions: not supported
843 # SquareOptions: not supported
844 # ZerosLikeOptions: not supported
845 # FillOptions: not supported
848 import tflite.LogicalAndOptions
849 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalAndOptions:
851 logical_and_option = tflite.LogicalAndOptions.LogicalAndOptions()
852 logical_and_option.Init(selected_builtin_option.Bytes,
853 selected_builtin_option.Pos)
855 tflite.LogicalAndOptions.LogicalAndOptionsStart(new_builder)
856 return tflite.LogicalAndOptions.LogicalAndOptionsEnd(new_builder)
858 # LogicalNotOptions: not supported
859 # UnpackOptions: not supported
860 # FloorDivOptions: not supported
861 # SquareOptions: not supported
862 # ZerosLikeOptions: not supported
863 # FillOptions: not supported
864 # BidirectionalSequenceLSTMOptions: not supported
865 # BidirectionalSequenceRNNOptions: not supported
866 # UnidirectionalSequenceLSTMOptions: not supported
867 # FloorModOptions: not supported
868 # RangeOptions: not supported
869 # ResizeNearestNeighborOptions: not supported
870 # LeakyReluOptions: not supported
872 # SquaredDifferenceOptions
873 import tflite.SquaredDifferenceOptions
874 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
875 ).SquaredDifferenceOptions:
877 squared_difference_option = tflite.SquaredDifferenceOptions.SquaredDifferenceOptions(
879 squared_difference_option.Init(selected_builtin_option.Bytes,
880 selected_builtin_option.Pos)
882 tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsStart(new_builder)
883 return tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsEnd(new_builder)
885 # MirrorPadOptions: not supported
886 # AbsOptions: not supported
887 # SplitVOptions: not supported
890 import tflite.IfOptions
891 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().IfOptions:
893 if_option = tflite.IfOptions.IfOptions()
894 if_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
896 tflite.IfOptions.IfOptionsStart(new_builder)
897 tflite.IfOptions.IfOptionsAddElseSubgraphIndex(
898 new_builder, used_subgraphs_dic[if_option.ElseSubgraphIndex()])
899 tflite.IfOptions.IfOptionsAddThenSubgraphIndex(
900 new_builder, used_subgraphs_dic[if_option.ThenSubgraphIndex()])
901 return tflite.IfOptions.IfOptionsEnd(new_builder)
904 import tflite.WhileOptions
905 if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().WhileOptions:
907 while_option = tflite.WhileOptions.WhileOptions()
908 while_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
910 tflite.WhileOptions.WhileOptionsStart(new_builder)
911 tflite.WhileOptions.WhileOptionsAddBodySubgraphIndex(
912 new_builder, used_subgraphs_dic[while_option.BodySubgraphIndex()])
913 tflite.WhileOptions.WhileOptionsAddCondSubgraphIndex(
914 new_builder, used_subgraphs_dic[while_option.CondSubgraphIndex()])
915 return tflite.WhileOptions.WhileOptionsEnd(new_builder)
917 # Cannot handle builtin option type yet
918 print("Cannot handle this option yet")
def GenerateOperator(new_builder, selected_operator, used_tensors_dic, used_opcodes_dic,
                     used_subgraphs_dic):
    """Write a re-indexed copy of one operator into `new_builder`.

    Args:
        new_builder: builder that receives the new operator table.
        selected_operator: operator read from the sample model.
        used_tensors_dic: old tensor index -> new tensor index.
        used_opcodes_dic: old opcode index -> new opcode index.
        used_subgraphs_dic: only forwarded to GenerateBuiltinOption.

    Returns:
        The result of tflite.Operator.OperatorEnd() for the new operator.
    """
    # Opcode index as it will appear in the new model
    old_opcode = selected_operator.OpcodeIndex()
    remapped_opcode = used_opcodes_dic[old_opcode]

    # Input vector: elements are prepended, so walk the source back to front
    n_inputs = selected_operator.InputsLength()
    has_inputs = n_inputs != 0
    if has_inputs:
        tflite.Operator.OperatorStartInputsVector(new_builder, n_inputs)
        for pos in range(n_inputs - 1, -1, -1):
            old_tensor = selected_operator.Inputs(pos)
            # -1 is copied as is instead of being translated
            # (presumably it marks an omitted optional input -- TODO confirm)
            remapped_tensor = -1 if old_tensor == -1 else used_tensors_dic[old_tensor]
            new_builder.PrependInt32(remapped_tensor)
        inputs_offset = new_builder.EndVector(n_inputs)

    # Output vector: every output index is translated
    n_outputs = selected_operator.OutputsLength()
    has_outputs = n_outputs != 0
    if has_outputs:
        tflite.Operator.OperatorStartOutputsVector(new_builder, n_outputs)
        for pos in range(n_outputs - 1, -1, -1):
            new_builder.PrependInt32(used_tensors_dic[selected_operator.Outputs(pos)])
        outputs_offset = new_builder.EndVector(n_outputs)

    # Builtin option table (type 0 means the operator carries none)
    option_type = selected_operator.BuiltinOptionsType()
    has_builtin_option = option_type != 0
    if has_builtin_option:
        option_offset = GenerateBuiltinOption(new_builder,
                                              selected_operator.BuiltinOptions(),
                                              option_type, used_subgraphs_dic)

    # Custom option bytes, copied one by one
    n_custom_bytes = selected_operator.CustomOptionsLength()
    has_custom_option = n_custom_bytes != 0
    if has_custom_option:
        tflite.Operator.OperatorStartCustomOptionsVector(new_builder, n_custom_bytes)
        for pos in range(n_custom_bytes - 1, -1, -1):
            new_builder.PrependUint8(selected_operator.CustomOptions(pos))
        custom_offset = new_builder.EndVector(n_custom_bytes)

    # Custom option format is copied unconditionally
    custom_format = selected_operator.CustomOptionsFormat()

    # Operator table referencing everything serialized above
    tflite.Operator.OperatorStart(new_builder)
    tflite.Operator.OperatorAddOpcodeIndex(new_builder, remapped_opcode)
    if has_inputs:
        tflite.Operator.OperatorAddInputs(new_builder, inputs_offset)
    if has_outputs:
        tflite.Operator.OperatorAddOutputs(new_builder, outputs_offset)
    tflite.Operator.OperatorAddBuiltinOptionsType(new_builder, option_type)
    if has_builtin_option:
        tflite.Operator.OperatorAddBuiltinOptions(new_builder, option_offset)
    if has_custom_option:
        tflite.Operator.OperatorAddCustomOptions(new_builder, custom_offset)
    tflite.Operator.OperatorAddCustomOptionsFormat(new_builder, custom_format)
    return tflite.Operator.OperatorEnd(new_builder)
def GenerateOperators(new_builder, selected_subgraph, operator_list, used_tensors_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Serialize the selected operators of `selected_subgraph` as a vector.

    Operators are emitted in the order they have in `selected_subgraph`,
    regardless of the order of `operator_list`.

    Args:
        new_builder: builder that receives the operator tables and the vector.
        selected_subgraph: subgraph of the sample model to copy operators from.
        operator_list: iterable of operator indices (in that subgraph) to keep.
        used_tensors_dic: old tensor index -> new tensor index.
        used_opcodes_dic: old opcode index -> new opcode index.
        used_subgraphs_dic: old subgraph index -> new subgraph index.

    Returns:
        The vector returned by the builder's EndVector(), or 0 when the
        subgraph has no operator or none of them is selected.
    """
    operator_num = selected_subgraph.OperatorsLength()
    if operator_num == 0:
        return 0

    # Build the lookup once instead of scanning operator_list for every operator
    selected_indices = set(operator_list)

    new_operator_list = [
        GenerateOperator(new_builder, selected_subgraph.Operators(operator_idx),
                         used_tensors_dic, used_opcodes_dic, used_subgraphs_dic)
        for operator_idx in range(operator_num) if operator_idx in selected_indices
    ]

    new_operator_num = len(new_operator_list)
    if new_operator_num == 0:
        return 0

    # Offsets are prepended, so feed them in reverse to keep the original order
    tflite.SubGraph.SubGraphStartOperatorsVector(new_builder, new_operator_num)
    for new_operator in reversed(new_operator_list):
        new_builder.PrependUOffsetTRelative(new_operator)

    return new_builder.EndVector(new_operator_num)
def GenerateSubgraph(new_builder, selected_subgraph, operator_list, new_input_tensor,
                     new_output_tensor, used_tensors_dic, used_buffers_dic,
                     used_opcodes_dic, used_subgraphs_dic):
    """Serialize one subgraph table into `new_builder`.

    Args:
        new_builder: builder that receives the subgraph.
        selected_subgraph: subgraph of the sample model to copy from.
        operator_list: operator indices of that subgraph to keep.
        new_input_tensor: old indices of the tensors to expose as inputs.
        new_output_tensor: old indices of the tensors to expose as outputs.
        used_tensors_dic: old tensor index -> new tensor index.
        used_buffers_dic: forwarded to GenerateTensors.
        used_opcodes_dic: forwarded to GenerateOperators.
        used_subgraphs_dic: forwarded to GenerateOperators.

    Returns:
        The result of tflite.SubGraph.SubGraphEnd() for the new subgraph.
    """

    def write_index_vector(start_vector, old_indices, count):
        # Elements are prepended, so the old indices are fed back to front
        start_vector(new_builder, count)
        for old_idx in reversed(old_indices):
            new_builder.PrependInt32(used_tensors_dic[old_idx])
        return new_builder.EndVector(count)

    # Tensors
    tensors_offset = GenerateTensors(new_builder, selected_subgraph, used_tensors_dic,
                                     used_buffers_dic)

    # Input index vector (left out entirely when there are no inputs)
    input_count = len(new_input_tensor)
    inputs_offset = None
    if input_count != 0:
        inputs_offset = write_index_vector(tflite.SubGraph.SubGraphStartInputsVector,
                                           new_input_tensor, input_count)

    # Output index vector (left out entirely when there are no outputs)
    output_count = len(new_output_tensor)
    outputs_offset = None
    if output_count != 0:
        outputs_offset = write_index_vector(tflite.SubGraph.SubGraphStartOutputsVector,
                                            new_output_tensor, output_count)

    # Operators
    operators_offset = GenerateOperators(new_builder, selected_subgraph, operator_list,
                                         used_tensors_dic, used_opcodes_dic,
                                         used_subgraphs_dic)

    # Name: copied only when the sample subgraph has a non-empty one
    raw_name = selected_subgraph.Name()
    name_offset = new_builder.CreateString(raw_name) if raw_name else None

    # Subgraph table referencing everything serialized above
    tflite.SubGraph.SubGraphStart(new_builder)
    tflite.SubGraph.SubGraphAddTensors(new_builder, tensors_offset)
    if inputs_offset is not None:
        tflite.SubGraph.SubGraphAddInputs(new_builder, inputs_offset)
    if outputs_offset is not None:
        tflite.SubGraph.SubGraphAddOutputs(new_builder, outputs_offset)
    tflite.SubGraph.SubGraphAddOperators(new_builder, operators_offset)
    if name_offset is not None:
        tflite.SubGraph.SubGraphAddName(new_builder, name_offset)

    return tflite.SubGraph.SubGraphEnd(new_builder)
def GenerateSubgraphs(args, new_builder, sample_model, operator_list, new_input_tensor,
                      new_output_tensor, used_tensors_dic, used_buffers_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Serialize every used subgraph and return the model's subgraph vector.

    The subgraph selected by `args.subgraph` is reduced to `operator_list`;
    every other subgraph in `used_subgraphs_dic` is copied with all of its
    operators, tensors, inputs and outputs.

    Args:
        args: parsed command line; only `args.subgraph` is read.
        new_builder: builder that receives the subgraphs.
        sample_model: model the subgraphs are copied from.
        operator_list: operator indices to keep in the primary subgraph.
        new_input_tensor: old tensor indices used as primary subgraph inputs.
        new_output_tensor: old tensor indices used as primary subgraph outputs.
        used_tensors_dic: old -> new tensor index for the primary subgraph.
        used_buffers_dic: old -> new buffer index; its keys are the kept buffers.
        used_opcodes_dic: old -> new opcode index.
        used_subgraphs_dic: old -> new subgraph index.

    Returns:
        The vector returned by the builder's EndVector().
    """
    new_subgraph_list = []

    # The selected subgraph will be primary subgraph of the model to be created newly
    selected_subgraph = sample_model.Subgraphs(args.subgraph)

    # GenerateBuffers keeps exactly the buffers named in used_buffers_dic, in ascending
    # old-index order, so the new index of a kept buffer is its rank among those keys.
    # Child tensors need that translation too: an identity map dangles as soon as any
    # buffer is dropped.
    # NOTE(review): assumes GenerateTensors translates tensor.Buffer() through this dic.
    child_buffers_dic = {
        old_idx: new_idx
        for new_idx, old_idx in enumerate(sorted(used_buffers_dic))
    }

    # k: old subg index, v: new subg index
    # Walk in ascending new index so that the position in the vector equals v
    # without relying on the iteration order of the dict
    for k, v in sorted(used_subgraphs_dic.items(), key=lambda item: item[1]):
        print("Append subgraphs, old index : ", k, ", new index : ", v)
        if k == args.subgraph:
            assert v == 0
            new_subgraph = GenerateSubgraph(new_builder, selected_subgraph, operator_list,
                                            new_input_tensor, new_output_tensor,
                                            used_tensors_dic, used_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
        else:
            subg = sample_model.Subgraphs(k)
            subg_operator_idx_list = range(subg.OperatorsLength())
            # Element accessors instead of InputsAsNumpy()/OutputsAsNumpy(): those
            # return the int 0 for an absent vector (which breaks len()) and need numpy
            subg_input_tensors = [subg.Inputs(i) for i in range(subg.InputsLength())]
            subg_output_tensors = [subg.Outputs(i) for i in range(subg.OutputsLength())]
            # Child subgraphs keep all their tensors, so tensor indices do not change
            subg_tensors_dic = {idx: idx for idx in range(subg.TensorsLength())}
            new_subgraph = GenerateSubgraph(new_builder, subg, subg_operator_idx_list,
                                            subg_input_tensors, subg_output_tensors,
                                            subg_tensors_dic, child_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
        new_subgraph_list.append(new_subgraph)

    # Offsets are prepended, so feed them in reverse to keep the order above
    new_subgraph_num = len(new_subgraph_list)
    tflite.Model.ModelStartSubgraphsVector(new_builder, new_subgraph_num)
    for new_subgraph in reversed(new_subgraph_list):
        new_builder.PrependUOffsetTRelative(new_subgraph)

    return new_builder.EndVector(new_subgraph_num)
def GenerateBuffers(new_builder, sample_model, used_buffers_dic):
    """Serialize the buffers listed in `used_buffers_dic` as the model's buffer vector.

    Buffers are emitted in ascending order of their index in `sample_model`;
    only the keys of `used_buffers_dic` are consulted.

    Args:
        new_builder: builder that receives the buffers.
        sample_model: model the buffers are copied from.
        used_buffers_dic: container whose keys are the buffer indices to keep.

    Returns:
        The vector returned by the builder's EndVector(), or 0 when the sample
        model has no buffer or none of them is kept.
    """
    total_buffers = sample_model.BuffersLength()
    if total_buffers == 0:
        return 0

    kept_indices = [idx for idx in range(total_buffers) if idx in used_buffers_dic]

    # First pass: data vectors of the kept, non-empty buffers (copied byte by byte)
    data_offsets = {}
    for idx in kept_indices:
        source = sample_model.Buffers(idx)
        size = source.DataLength()
        if size == 0:
            continue
        tflite.Buffer.BufferStartDataVector(new_builder, size)
        for byte_pos in range(size - 1, -1, -1):
            new_builder.PrependUint8(source.Data(byte_pos))
        data_offsets[idx] = new_builder.EndVector(size)

    # Second pass: one buffer table per kept buffer, with data only when it had some
    table_offsets = []
    for idx in kept_indices:
        tflite.Buffer.BufferStart(new_builder)
        if idx in data_offsets:
            tflite.Buffer.BufferAddData(new_builder, data_offsets[idx])
        table_offsets.append(tflite.Buffer.BufferEnd(new_builder))

    # Buffer vector
    kept_count = len(table_offsets)
    if kept_count == 0:
        return 0

    tflite.Model.ModelStartBuffersVector(new_builder, kept_count)
    for table_offset in reversed(table_offsets):
        new_builder.PrependUOffsetTRelative(table_offset)

    return new_builder.EndVector(kept_count)
def GenerateModel(args, new_builder, sample_model, operator_list, new_input_tensors,
                  new_output_tensors, used_tensors_dic, used_buffers_dic,
                  used_opcodes_dic, used_subgraphs_dic):
    """Serialize the reduced model table into `new_builder`.

    Args:
        args: parsed command line, forwarded to GenerateSubgraphs.
        new_builder: builder that receives the model.
        sample_model: model to copy from.
        operator_list: operator indices to keep in the selected subgraph.
        new_input_tensors: old tensor indices used as inputs of that subgraph.
        new_output_tensors: old tensor indices used as outputs of that subgraph.
        used_tensors_dic: old -> new tensor index.
        used_buffers_dic: old -> new buffer index.
        used_opcodes_dic: old -> new opcode index.
        used_subgraphs_dic: old -> new subgraph index.

    Returns:
        The result of tflite.Model.ModelEnd() for the new model.
    """
    # uint
    version = sample_model.Version()

    # pointer of operator code 'table' vector
    operator_codes = GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
                                           used_subgraphs_dic)

    # subgraphs
    subgraphs = GenerateSubgraphs(args, new_builder, sample_model, operator_list,
                                  new_input_tensors, new_output_tensors, used_tensors_dic,
                                  used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)

    # description: the accessor yields None when the sample model carries none, and
    # CreateString() rejects None, so the field is only written when it is present
    description = sample_model.Description()
    description_string = None
    if description is not None:
        description_string = new_builder.CreateString(description)

    # buffers
    buffers = GenerateBuffers(new_builder, sample_model, used_buffers_dic)

    # Generate model
    tflite.Model.ModelStart(new_builder)
    tflite.Model.ModelAddVersion(new_builder, version)
    tflite.Model.ModelAddOperatorCodes(new_builder, operator_codes)
    tflite.Model.ModelAddSubgraphs(new_builder, subgraphs)
    if description_string is not None:
        tflite.Model.ModelAddDescription(new_builder, description_string)
    tflite.Model.ModelAddBuffers(new_builder, buffers)

    return tflite.Model.ModelEnd(new_builder)
def main(args):
    """Extract the selected operators of one subgraph into a new tflite model.

    Reads the model from `args.input_model` and the operator selection from
    `args.opcode_list`, keeps only those operators of subgraph `args.subgraph`
    (plus every subgraph GetUsedSubgraphsList adds for them), renumbers
    subgraphs, opcodes, tensors and buffers, and writes the resulting model to
    `args.output_model`.

    Args:
        args: namespace with `input_model` (binary readable file),
            `opcode_list` (text file), `output_model` (binary writable file)
            and `subgraph` (int index).
    """
    input_model_file = args.input_model
    oplist_file = args.opcode_list
    output_model_file = args.output_model
    subgraph = args.subgraph

    # Parse operator list file
    operator_list = GetOperatorList(oplist_file)

    # Get sample model and the subgraph selected on the command line
    sample_buf = bytearray(input_model_file.read())
    sample_model = tflite.Model.Model.GetRootAsModel(sample_buf, 0)
    sample_subgraph = sample_model.Subgraphs(subgraph)

    # The selected subgraph comes first, so it gets new index 0
    used_subgraphs_list = [subgraph]
    GetUsedSubgraphsList(sample_model, subgraph, operator_list, used_subgraphs_list)
    used_subgraphs_dic = {
        old_idx: new_idx
        for new_idx, old_idx in enumerate(used_subgraphs_list)
    }
    child_subgraphs = [
        sample_model.Subgraphs(idx) for idx in used_subgraphs_list if idx != subgraph
    ]

    # Collect used tensor & used operator
    consumed_tensors = set()  # inputs of the selected operators (-1 entries skipped)
    produced_tensors = set()  # outputs of the selected operators
    used_opcode_set = set()
    for operator_idx in operator_list:
        operator = sample_subgraph.Operators(operator_idx)
        for input_idx in range(operator.InputsLength()):
            input_tensor_idx = operator.Inputs(input_idx)
            if input_tensor_idx != -1:
                consumed_tensors.add(input_tensor_idx)
        for output_idx in range(operator.OutputsLength()):
            produced_tensors.add(operator.Outputs(output_idx))
        used_opcode_set.add(operator.OpcodeIndex())

    # Append opcodes of child subgraphs (they are copied with all their operators)
    for child in child_subgraphs:
        for operator_idx in range(child.OperatorsLength()):
            used_opcode_set.add(child.Operators(operator_idx).OpcodeIndex())

    used_tensors = sorted(consumed_tensors | produced_tensors)
    used_opcodes = sorted(used_opcode_set)

    # Collect used buffer
    # buffer[0] should be blank. So it is always kept
    used_buffer_set = {0}
    for used_tensor in used_tensors:
        used_buffer_set.add(sample_subgraph.Tensors(used_tensor).Buffer())

    # Append buffers of tensors of child subgraphs
    for child in child_subgraphs:
        for tensor_idx in range(child.TensorsLength()):
            used_buffer_set.add(child.Tensors(tensor_idx).Buffer())

    # Each buffer must be listed once: GenerateBuffers emits one buffer per distinct
    # index in ascending order, so a repeated entry here would shift every later
    # new index away from the position the buffer really gets
    used_buffers = sorted(used_buffer_set)

    # Assign new index for operator, tensor and buffer (position in the sorted list)
    used_opcodes_dic = {old: new for new, old in enumerate(used_opcodes)}
    used_tensors_dic = {old: new for new, old in enumerate(used_tensors)}
    used_buffers_dic = {old: new for new, old in enumerate(used_buffers)}

    def has_constant_data(tensor_idx):
        # True when the buffer behind the tensor holds data in the sample model
        buffer_idx = sample_subgraph.Tensors(tensor_idx).Buffer()
        return sample_model.Buffers(buffer_idx).DataLength() != 0

    # Find input & output tensor in new model
    # input : not produced by any selected operator and not backed by constant data
    # output: not consumed by any selected operator and not backed by constant data
    # (the output side used to remove the last seen *input* index instead of the
    # output itself, which raised or dropped the wrong tensor)
    new_input_tensors = [
        tensor_idx for tensor_idx in used_tensors
        if tensor_idx not in produced_tensors and not has_constant_data(tensor_idx)
    ]
    new_output_tensors = [
        tensor_idx for tensor_idx in used_tensors
        if tensor_idx not in consumed_tensors and not has_constant_data(tensor_idx)
    ]

    new_input_tensors_newidx = [used_tensors_dic[idx] for idx in new_input_tensors]
    new_output_tensors_newidx = [used_tensors_dic[idx] for idx in new_output_tensors]

    print("Input tensor(s): " + str(new_input_tensors_newidx))
    print("Output tensor(s): " + str(new_output_tensors_newidx))

    # Create new model file
    new_builder = flatbuffers.Builder(1024)

    new_model = GenerateModel(args, new_builder, sample_model, operator_list,
                              new_input_tensors, new_output_tensors, used_tensors_dic,
                              used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)

    new_builder.Finish(new_model, file_identifier=b'TFL3')
    new_buf = new_builder.Output()

    output_model_file.write(new_buf)
if __name__ == '__main__':
    # Define argument and read
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "input_model",
        type=argparse.FileType('rb'),
        help="input tflite model file to read")
    arg_parser.add_argument(
        "opcode_list",
        type=argparse.FileType('r'),
        help="text file including selected operator list")
    arg_parser.add_argument(
        "output_model", type=argparse.FileType('wb'), help="output tflite model file")
    arg_parser.add_argument(
        '-g', '--subgraph', type=int, default=0, help="subgraph to use (default: 0)")

    # TODO
    #   Select multiple subgraph
    #   Select subgraph by using opcode list file
    #   Select opcode list by using argument

    args = arg_parser.parse_args()

    # Call main function. argparse has already opened the three files; the with
    # block closes them (flushing the output model) even when main() exits early.
    with args.input_model, args.opcode_list, args.output_model:
        main(args)