Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / tools / tflitefile_tool / select_operator.py
1 #!/usr/bin/env python
2
3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import os
18 import sys
19 import numpy
20 import flatbuffers
21 import tflite.Model
22 import tflite.SubGraph
23 import tflite.BuiltinOptions
24 import argparse
25
26
27 # Assume we use only main model in model file
28 # Get selected operators from file, and return operator index list
def _ParseOperatorIndex(text):
    """Return ``text`` as a non-negative int, or None if it is not a plain index."""
    if not text.isdigit():
        return None
    try:
        return int(text)
    except ValueError:
        # str.isdigit() also accepts characters such as superscript digits
        # ("\u00b2") that int() cannot convert.
        return None


def GetOperatorList(oplist_file):
    """Read the selected operator indices from an open text file.

    Each whitespace-separated token, on any number of lines, must be either a
    single non-negative index (``7``) or an inclusive ascending range
    (``3-5``). Ranges are expanded in place; order and duplicates are kept.

    Args:
        oplist_file: readable text file object.

    Returns:
        list of int operator indices.

    Exits the process with status 1 (after printing a message) when a token
    cannot be parsed, when a range is reversed, or when no index is found.
    """
    opcode_list = []

    for line in oplist_file:
        for word in line.split():
            single = _ParseOperatorIndex(word)
            if single is not None:
                opcode_list.append(single)
                continue

            bounds = word.split('-')
            start = end = None
            if len(bounds) == 2:
                start = _ParseOperatorIndex(bounds[0])
                end = _ParseOperatorIndex(bounds[1])
            # A reversed range ("5-3") would otherwise expand to nothing silently.
            if start is None or end is None or start > end:
                print("Error: Cannot get operator list")
                print(
                    "Please pass operators as operator index or range list split by space and/or line"
                )
                sys.exit(1)
            opcode_list.extend(range(start, end + 1))

    if not opcode_list:
        print("No selected operator")
        sys.exit(1)

    return opcode_list
56
57     return opcode_list
58
59
def GetUsedSubgraphsList(sample_model, subg_num, operator_list, used_subgraphs_list):
    """Recursively collect the subgraphs referenced by IF/WHILE operators.

    Every operator of subgraph ``subg_num`` whose index is in ``operator_list``
    is inspected. IF operators contribute their else/then subgraph indices and
    WHILE operators their body/cond subgraph indices. Each index not yet in
    ``used_subgraphs_list`` is appended, and then all operators of that
    subgraph are scanned the same way.

    Args:
        sample_model: source tflite.Model.
        subg_num: index of the subgraph whose operators are scanned.
        operator_list: iterable of operator indices inside ``subg_num``.
        used_subgraphs_list: list extended in place. Indices already present
            are not revisited, which also stops recursion on cyclic references.

    Returns:
        None; results are accumulated in ``used_subgraphs_list``.
    """
    import tflite.BuiltinOptions
    import tflite.IfOptions
    import tflite.WhileOptions

    options_enum = tflite.BuiltinOptions.BuiltinOptions
    selected_subgraph = sample_model.Subgraphs(subg_num)
    subg_list = []

    for operator_idx in operator_list:
        selected_operator = selected_subgraph.Operators(operator_idx)
        options_type = selected_operator.BuiltinOptionsType()

        if options_type == options_enum.IfOptions:
            options_table = selected_operator.BuiltinOptions()
            if_option = tflite.IfOptions.IfOptions()
            if_option.Init(options_table.Bytes, options_table.Pos)
            subg_list.append(if_option.ElseSubgraphIndex())
            subg_list.append(if_option.ThenSubgraphIndex())
        elif options_type == options_enum.WhileOptions:
            options_table = selected_operator.BuiltinOptions()
            while_option = tflite.WhileOptions.WhileOptions()
            while_option.Init(options_table.Bytes, options_table.Pos)
            subg_list.append(while_option.BodySubgraphIndex())
            subg_list.append(while_option.CondSubgraphIndex())

    for idx in subg_list:
        if idx not in used_subgraphs_list:
            used_subgraphs_list.append(idx)
            # Scan *all* operators of the referenced subgraph, including the
            # last one, which may itself be a control-flow operator.
            GetUsedSubgraphsList(sample_model, idx,
                                 range(sample_model.Subgraphs(idx).OperatorsLength()),
                                 used_subgraphs_list)
92                                  range(sample_model.Subgraphs(idx).OperatorsLength() - 1),
93                                  used_subgraphs_list)
94
95
def GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
                          used_subgraphs_dic):
    """Serialize the operator codes still in use into ``new_builder``.

    Operator codes of ``sample_model`` whose index is contained in
    ``used_opcodes_dic`` are copied in ascending original-index order, keeping
    builtin code, custom code string and operator version.

    NOTE(review): the caller presumably maps old opcode indices to new ones
    assuming exactly this ascending order — verify against the caller.

    Args:
        new_builder: flatbuffers.Builder being written.
        sample_model: source tflite.Model.
        used_opcodes_dic: container of original opcode indices to keep
            (only membership is tested here).
        used_subgraphs_dic: unused; kept for interface compatibility.

    Returns:
        Offset of the new operator_codes vector, or 0 when the source model has
        no operator codes at all.
    """
    operator_code_num = sample_model.OperatorCodesLength()
    if operator_code_num == 0:
        return 0

    used_operator_codes = [
        sample_model.OperatorCodes(idx) for idx in range(operator_code_num)
        if idx in used_opcodes_dic
    ]

    # Strings must be created before the tables that reference them are
    # started; identical custom codes share a single string.
    custom_code_strings = {}
    for operator_code in used_operator_codes:
        custom_code = operator_code.CustomCode()
        if custom_code and custom_code not in custom_code_strings:
            custom_code_strings[custom_code] = new_builder.CreateString(custom_code)

    new_operator_code_list = []
    for operator_code in used_operator_codes:
        tflite.OperatorCode.OperatorCodeStart(new_builder)
        tflite.OperatorCode.OperatorCodeAddBuiltinCode(new_builder,
                                                       operator_code.BuiltinCode())
        custom_code = operator_code.CustomCode()
        if custom_code in custom_code_strings:
            tflite.OperatorCode.OperatorCodeAddCustomCode(
                new_builder, custom_code_strings[custom_code])
        # Preserve the operator version; dropping it would make every op
        # default to version 1 in the generated model.
        tflite.OperatorCode.OperatorCodeAddVersion(new_builder, operator_code.Version())
        new_operator_code_list.append(tflite.OperatorCode.OperatorCodeEnd(new_builder))

    new_operator_code_num = len(new_operator_code_list)
    tflite.Model.ModelStartOperatorCodesVector(new_builder, new_operator_code_num)
    # Prepend in reverse so the finished vector keeps the original order.
    for new_operator_code in reversed(new_operator_code_list):
        new_builder.PrependUOffsetTRelative(new_operator_code)

    return new_builder.EndVector(new_operator_code_num)
137
138     return new_builder.EndVector(new_operator_code_num)
139
140
def _CreateQuantizationVector(new_builder, start_vector, length, get_item, prepend):
    """Serialize ``get_item(0) .. get_item(length - 1)`` as a flatbuffer vector.

    Args:
        new_builder: flatbuffers.Builder being written.
        start_vector: generated ``QuantizationParametersStart*Vector`` function.
        length: number of elements.
        get_item: accessor returning the source element at a given index.
        prepend: builder method that prepends one element of the right width.

    Returns:
        Offset of the finished vector.
    """
    start_vector(new_builder, length)
    # FlatBuffers vectors are filled back-to-front.
    for idx in reversed(range(length)):
        prepend(get_item(idx))
    return new_builder.EndVector(length)


def GenerateQuantization(new_builder, selected_quantization):
    """Copy a QuantizationParameters table into ``new_builder``.

    Copies the non-empty min/max/scale/zero_point vectors and the
    ``quantized_dimension`` field. That field is the per-channel quantization
    axis; dropping it corrupts per-channel tensors quantized along an axis
    other than 0.

    Args:
        new_builder: flatbuffers.Builder being written.
        selected_quantization: source tflite.QuantizationParameters.

    Returns:
        Offset of the new QuantizationParameters table.
    """
    params = tflite.QuantizationParameters
    q = selected_quantization

    # Child vectors must be finished before the table itself is started.
    new_min = new_max = new_scale = new_zeropoint = None

    min_num = q.MinLength()
    if min_num != 0:
        new_min = _CreateQuantizationVector(
            new_builder, params.QuantizationParametersStartMinVector, min_num, q.Min,
            new_builder.PrependFloat32)

    max_num = q.MaxLength()
    if max_num != 0:
        new_max = _CreateQuantizationVector(
            new_builder, params.QuantizationParametersStartMaxVector, max_num, q.Max,
            new_builder.PrependFloat32)

    scale_num = q.ScaleLength()
    if scale_num != 0:
        new_scale = _CreateQuantizationVector(
            new_builder, params.QuantizationParametersStartScaleVector, scale_num,
            q.Scale, new_builder.PrependFloat32)

    zeropoint_num = q.ZeroPointLength()
    if zeropoint_num != 0:
        new_zeropoint = _CreateQuantizationVector(
            new_builder, params.QuantizationParametersStartZeroPointVector,
            zeropoint_num, q.ZeroPoint, new_builder.PrependInt64)

    quantized_dimension = q.QuantizedDimension()

    params.QuantizationParametersStart(new_builder)
    if new_min is not None:
        params.QuantizationParametersAddMin(new_builder, new_min)
    if new_max is not None:
        params.QuantizationParametersAddMax(new_builder, new_max)
    if new_scale is not None:
        params.QuantizationParametersAddScale(new_builder, new_scale)
    if new_zeropoint is not None:
        params.QuantizationParametersAddZeroPoint(new_builder, new_zeropoint)
    params.QuantizationParametersAddQuantizedDimension(new_builder, quantized_dimension)

    return params.QuantizationParametersEnd(new_builder)
190
191     return tflite.QuantizationParameters.QuantizationParametersEnd(new_builder)
192
193
def GenerateTensor(new_builder, selected_tensor, used_buffers_dic):
    """Copy one Tensor table into ``new_builder``.

    Copies shape, type, buffer index (remapped), name (when non-empty),
    quantization (when present) and the ``is_variable`` flag.

    Args:
        new_builder: flatbuffers.Builder being written.
        selected_tensor: source tflite.Tensor.
        used_buffers_dic: mapping from the tensor's original buffer index to
            its index in the new model.

    Returns:
        Offset of the new Tensor table.

    Raises:
        KeyError: if the tensor's buffer index is missing from used_buffers_dic.
    """
    # Shape vector (always emitted, possibly empty); filled back-to-front.
    shape_num = selected_tensor.ShapeLength()
    tflite.Tensor.TensorStartShapeVector(new_builder, shape_num)
    for shape_idx in reversed(range(shape_num)):
        new_builder.PrependInt32(selected_tensor.Shape(shape_idx))
    new_shape = new_builder.EndVector(shape_num)

    tensor_type = selected_tensor.Type()
    new_buffer_idx = used_buffers_dic[selected_tensor.Buffer()]

    # Name() may be None (unset) or empty; only serialize a real name.
    # Passing None to CreateString would fail.
    name_string = selected_tensor.Name()
    new_name = new_builder.CreateString(name_string) if name_string else None

    quantization = selected_tensor.Quantization()
    new_quantization = None
    if quantization is not None:
        new_quantization = GenerateQuantization(new_builder, quantization)

    is_variable = selected_tensor.IsVariable()

    # All child objects are finished above; now build the table itself.
    tflite.Tensor.TensorStart(new_builder)
    tflite.Tensor.TensorAddShape(new_builder, new_shape)
    tflite.Tensor.TensorAddType(new_builder, tensor_type)
    tflite.Tensor.TensorAddBuffer(new_builder, new_buffer_idx)
    if new_name is not None:
        tflite.Tensor.TensorAddName(new_builder, new_name)
    if new_quantization is not None:
        tflite.Tensor.TensorAddQuantization(new_builder, new_quantization)
    tflite.Tensor.TensorAddIsVariable(new_builder, is_variable)

    return tflite.Tensor.TensorEnd(new_builder)
230
231     return tflite.Tensor.TensorEnd(new_builder)
232
233
def GenerateTensors(new_builder, selected_subgraph, used_tensors_dic, used_buffers_dic):
    """Serialize the tensors of ``selected_subgraph`` that are still in use.

    Tensors are visited in ascending original index order. Only those whose
    index is contained in ``used_tensors_dic`` are copied (via GenerateTensor),
    and the resulting vector keeps that order.

    Args:
        new_builder: flatbuffers.Builder being written.
        selected_subgraph: source tflite.SubGraph.
        used_tensors_dic: container of original tensor indices to keep.
        used_buffers_dic: forwarded unchanged to GenerateTensor.

    Returns:
        Offset of the new tensors vector, or 0 when the subgraph has no
        tensors or none of them is kept.
    """
    total = selected_subgraph.TensorsLength()
    if total == 0:
        return 0

    kept_offsets = [
        GenerateTensor(new_builder, selected_subgraph.Tensors(idx), used_buffers_dic)
        for idx in range(total)
        if idx in used_tensors_dic
    ]
    count = len(kept_offsets)
    if not count:
        return 0

    tflite.SubGraph.SubGraphStartTensorsVector(new_builder, count)
    # Prepend in reverse so the finished vector keeps the original order.
    for offset in kept_offsets[::-1]:
        new_builder.PrependUOffsetTRelative(offset)
    return new_builder.EndVector(count)
254
255     return new_builder.EndVector(new_tensor_num)
256
257
258 def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type,
259                           used_subgraphs_dic):
260
261     # Conv2D option
262     import tflite.Conv2DOptions
263     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Conv2DOptions:
264
265         conv2d_options = tflite.Conv2DOptions.Conv2DOptions()
266         conv2d_options.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
267
268         tflite.Conv2DOptions.Conv2DOptionsStart(new_builder)
269         tflite.Conv2DOptions.Conv2DOptionsAddPadding(new_builder,
270                                                      conv2d_options.Padding())
271         tflite.Conv2DOptions.Conv2DOptionsAddStrideW(new_builder,
272                                                      conv2d_options.StrideW())
273         tflite.Conv2DOptions.Conv2DOptionsAddStrideH(new_builder,
274                                                      conv2d_options.StrideH())
275         tflite.Conv2DOptions.Conv2DOptionsAddDilationWFactor(
276             new_builder, conv2d_options.DilationWFactor())
277         tflite.Conv2DOptions.Conv2DOptionsAddDilationHFactor(
278             new_builder, conv2d_options.DilationHFactor())
279         tflite.Conv2DOptions.Conv2DOptionsAddFusedActivationFunction(
280             new_builder, conv2d_options.FusedActivationFunction())
281         return tflite.Conv2DOptions.Conv2DOptionsEnd(new_builder)
282
283     # DepthwiseConv2D option
284     import tflite.DepthwiseConv2DOptions
285     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
286     ).DepthwiseConv2DOptions:
287
288         depthconv2d_option = tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions()
289         depthconv2d_option.Init(selected_builtin_option.Bytes,
290                                 selected_builtin_option.Pos)
291
292         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsStart(new_builder)
293         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddPadding(
294             new_builder, depthconv2d_option.Padding())
295         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideW(
296             new_builder, depthconv2d_option.StrideW())
297         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideH(
298             new_builder, depthconv2d_option.StrideH())
299         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDepthMultiplier(
300             new_builder, depthconv2d_option.DepthMultiplier())
301         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddFusedActivationFunction(
302             new_builder, depthconv2d_option.FusedActivationFunction())
303         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationWFactor(
304             new_builder, depthconv2d_option.DilationWFactor())
305         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationHFactor(
306             new_builder, depthconv2d_option.DilationHFactor())
307         return tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsEnd(new_builder)
308
309     # ConcatEmbeddingsOptions: not supported
310     # LSHProjectionOptions: not supported
311
312     # Pool2DPOption
313     import tflite.Pool2DOptions
314     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Pool2DOptions:
315
316         pool2d_option = tflite.Pool2DOptions.Pool2DOptions()
317         pool2d_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
318
319         tflite.Pool2DOptions.Pool2DOptionsStart(new_builder)
320         tflite.Pool2DOptions.Pool2DOptionsAddPadding(new_builder, pool2d_option.Padding())
321         tflite.Pool2DOptions.Pool2DOptionsAddStrideW(new_builder, pool2d_option.StrideW())
322         tflite.Pool2DOptions.Pool2DOptionsAddStrideH(new_builder, pool2d_option.StrideH())
323         tflite.Pool2DOptions.Pool2DOptionsAddFilterWidth(new_builder,
324                                                          pool2d_option.FilterWidth())
325         tflite.Pool2DOptions.Pool2DOptionsAddFilterHeight(new_builder,
326                                                           pool2d_option.FilterHeight())
327         tflite.Pool2DOptions.Pool2DOptionsAddFusedActivationFunction(
328             new_builder, pool2d_option.FusedActivationFunction())
329         return tflite.Pool2DOptions.Pool2DOptionsEnd(new_builder)
330
331     # SVDFOptions: not supported
332
333     # RNNOptions
334     import tflite.RNNOptions
335     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().RNNOptions:
336
337         rnn_option = tflite.RNNOptions.RNNOptions()
338         rnn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
339
340         tflite.RNNOptions.RNNOptionsStart(new_builder)
341         tflite.RNNOptions.RNNOptionsAddFusedActivationFunction(
342             new_builder, rnn_option.FusedActivationFunction())
343         return tflite.RNNOptions.RNNOptionsEnd(new_builder)
344
345     # FullyConnectedOptions
346     import tflite.FullyConnectedOptions
347     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
348     ).FullyConnectedOptions:
349
350         fc_option = tflite.FullyConnectedOptions.FullyConnectedOptions()
351         fc_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
352
353         tflite.FullyConnectedOptions.FullyConnectedOptionsStart(new_builder)
354         tflite.FullyConnectedOptions.FullyConnectedOptionsAddFusedActivationFunction(
355             new_builder, fc_option.FusedActivationFunction())
356         return tflite.FullyConnectedOptions.FullyConnectedOptionsEnd(new_builder)
357
358     # SoftmaxOptions
359     import tflite.SoftmaxOptions
360     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SoftmaxOptions:
361
362         softmax_option = tflite.SoftmaxOptions.SoftmaxOptions()
363         softmax_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
364
365         tflite.SoftmaxOptions.SoftmaxOptionsStart(new_builder)
366         tflite.SoftmaxOptions.SoftmaxOptionsAddBeta(new_builder, softmax_option.Beta())
367         return tflite.SoftmaxOptions.SoftmaxOptionsEnd(new_builder)
368
369     # ConcatenationOptions
370     import tflite.ConcatenationOptions
371     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ConcatenationOptions:
372
373         concat_option = tflite.ConcatenationOptions.ConcatenationOptions()
374         concat_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
375
376         tflite.ConcatenationOptions.ConcatenationOptionsStart(new_builder)
377         tflite.ConcatenationOptions.ConcatenationOptionsAddAxis(
378             new_builder, concat_option.Axis())
379         tflite.ConcatenationOptions.ConcatenationOptionsAddFusedActivationFunction(
380             new_builder, concat_option.FusedActivationFunction())
381         return tflite.ConcatenationOptions.ConcatenationOptionsEnd(new_builder)
382
383     # AddOptions
384     import tflite.AddOptions
385     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions:
386
387         add_option = tflite.AddOptions.AddOptions()
388         add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
389
390         tflite.AddOptions.AddOptionsStart(new_builder)
391         tflite.AddOptions.AddOptionsAddFusedActivationFunction(
392             new_builder, add_option.FusedActivationFunction())
393         return tflite.AddOptions.AddOptionsEnd(new_builder)
394
395     # L2NormOptions
396     import tflite.L2NormOptions
397     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().L2NormOptions:
398
399         l2norm_option = tflite.L2NormOptions.L2NormOptions()
400         l2norm_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
401
402         tflite.L2NormOptions.L2NormOptionsStart(new_builder)
403         tflite.L2NormOptions.L2NormOptionsAddFusedActivationFunction(
404             new_builder, l2norm_option.FusedActivationFunction())
405         return tflite.L2NormOptions.L2NormOptionsEnd(new_builder)
406
407     # LocalResponseNormalizationOptions
408     import tflite.LocalResponseNormalizationOptions
409     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
410     ).LocalResponseNormalizationOptions:
411
412         lrn_option = tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptions(
413         )
414         lrn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
415
416         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsStart(
417             new_builder)
418         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddRadius(
419             new_builder, lrn_option.Radius())
420         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBias(
421             new_builder, lrn_option.Bias())
422         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddAlpha(
423             new_builder, lrn_option.Alpha())
424         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBeta(
425             new_builder, lrn_option.Beta())
426         return tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsEnd(
427             new_builder)
428
429     # LSTMOptions: not supported
430
431     # ResizeBilinearOptions
432     import tflite.ResizeBilinearOptions
433     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
434     ).ResizeBilinearOptions:
435
436         resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions()
437         resize_bilinear_option.Init(selected_builtin_option.Bytes,
438                                     selected_builtin_option.Pos)
439
440         tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder)
441         tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners(
442             new_builder, resize_bilinear_option.AlignCorners())
443         return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder)
444
445     # CallOptions: not supported
446
447     # ReshapeOptions
448     import tflite.ReshapeOptions
449     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReshapeOptions:
450
451         reshape_option = tflite.ReshapeOptions.ReshapeOptions()
452         reshape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
453
454         shape_num = reshape_option.NewShapeLength()
455         if shape_num != 0:
456             tflite.ReshapeOptions.ReshapeOptionsStartNewShapeVector(
457                 new_builder, shape_num)
458             for new_shape_idx in reversed(range(shape_num)):
459                 new_shape_val = reshape_option.NewShape(new_shape_idx)
460                 new_builder.PrependInt32(new_shape_val)
461             new_shape = new_builder.EndVector(shape_num)
462
463         tflite.ReshapeOptions.ReshapeOptionsStart(new_builder)
464         if shape_num != 0:
465             tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape)
466         return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder)
467
468     # SkipGramOptions: not supported
469
470     # SpaceToDepthOptions
471     import tflite.SpaceToDepthOptions
472     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SpaceToDepthOptions:
473
474         space_to_depth_option = tflite.SpaceToDepthOptions.SpaceToDepthOptions()
475         space_to_depth_option.Init(selected_builtin_option.Bytes,
476                                    selected_builtin_option.Pos)
477
478         tflite.SpaceToDepthOptions.SpaceToDepthOptionsStart(new_builder)
479         tflite.SpaceToDepthOptions.SpaceToDepthOptionsAddBlockSize(
480             new_builder, space_to_depth_option.BlockSize())
481         return tflite.SpaceToDepthOptions.SpaceToDepthOptionsEnd(new_builder)
482
483     # EmbeddingLookupSparseOptions: not supported
484
485     # MulOptions
486     import tflite.MulOptions
487     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().MulOptions:
488
489         mul_option = tflite.MulOptions.MulOptions()
490         mul_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
491
492         tflite.MulOptions.MulOptionsStart(new_builder)
493         tflite.MulOptions.MulOptionsAddFusedActivationFunction(
494             new_builder, mul_option.FusedActivationFunction())
495         return tflite.MulOptions.MulOptionsEnd(new_builder)
496
497     # PadOptions
498     import tflite.PadOptions
499     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PadOptions:
500
501         pad_option = tflite.PadOptions.PadOptions()
502         pad_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
503
504         tflite.PadOptions.PadOptionsStart(new_builder)
505         return tflite.PadOptions.PadOptionsEnd(new_builder)
506
507     # GatherOptions
508     import tflite.GatherOptions
509     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().GatherOptions:
510
511         gather_option = tflite.GatherOptions.GatherOptions()
512         gather_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
513
514         tflite.GatherOptions.GatherOptionsStart(new_builder)
515         tflite.GatherOptions.GatherOptionsAddAxis(new_builder, gather_option.Axis())
516         return tflite.GatherOptions.GatherOptionsEnd(new_builder)
517
518     # BatchToSpaceNDOptions
519     import tflite.BatchToSpaceNDOptions
520     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
521     ).BatchToSpaceNDOptions:
522
523         btsnd_option = tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptions()
524         btsnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
525
526         tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsStart(new_builder)
527         return tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsEnd(new_builder)
528
529     # SpaceToBatchNDOptions
530     import tflite.SpaceToBatchNDOptions
531     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
532     ).SpaceToBatchNDOptions:
533
534         stbnd_option = tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptions()
535         stbnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
536
537         tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsStart(new_builder)
538         return tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsEnd(new_builder)
539
540     # TransposeOptions:
541     import tflite.TransposeOptions
542     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeOptions:
543
544         transpose_option = tflite.TransposeOptions.TransposeOptions()
545         transpose_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
546
547         tflite.TransposeOptions.TransposeOptionsStart(new_builder)
548         return tflite.TransposeOptions.TransposeOptionsEnd(new_builder)
549
550     # ReducerOptions
551     import tflite.ReducerOptions
552     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReducerOptions:
553
554         reducer_option = tflite.ReducerOptions.ReducerOptions()
555         reducer_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
556
557         tflite.ReducerOptions.ReducerOptionsStart(new_builder)
558         tflite.ReducerOptions.ReducerOptionsAddKeepDims(new_builder,
559                                                         reducer_option.KeepDims())
560         return tflite.ReducerOptions.ReducerOptionsEnd(new_builder)
561
562     # SubOptions
563     import tflite.SubOptions
564     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions:
565
566         sub_option = tflite.SubOptions.SubOptions()
567         sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
568
569         tflite.SubOptions.SubOptionsStart(new_builder)
570         tflite.SubOptions.SubOptionsAddFusedActivationFunction(
571             new_builder, sub_option.FusedActivationFunction())
572         return tflite.SubOptions.SubOptionsEnd(new_builder)
573
574     # DivOptions
575     import tflite.DivOptions
576     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DivOptions:
577
578         div_option = tflite.DivOptions.DivOptions()
579         div_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
580
581         tflite.DivOptions.DivOptionsStart(new_builder)
582         tflite.DivOptions.DivOptionsAddFusedActivationFunction(
583             new_builder, div_option.FusedActivationFunction())
584         return tflite.DivOptions.DivOptionsEnd(new_builder)
585
586     # SqueezeOptions
587     import tflite.SqueezeOptions
588     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SqueezeOptions:
589
590         squeeze_option = tflite.SqueezeOptions.SqueezeOptions()
591         squeeze_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
592
593         squeeze_dims_num = squeeze_option.SqueezeDimsLength()
594         if squeeze_dims_num != 0:
595             tflite.SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector(
596                 new_builder, squeeze_dims_num)
597             for squeeze_dims_idx in reversed(range(squeeze_dims_num)):
598                 squeeze_dims_val = squeeze_option.SqueezeDims(squeeze_dims_idx)
599                 new_builder.PrependInt32(squeeze_dims_val)
600             new_squeeze_dims = new_builder.EndVector(squeeze_dims_num)
601
602         tflite.SqueezeOptions.SqueezeOptionsStart(new_builder)
603         if squeeze_dims_num != 0:
604             tflite.SqueezeOptions.SqueezeOptionsAddSqueezeDims(new_builder,
605                                                                new_squeeze_dims)
606         return tflite.SqueezeOptions.SqueezeOptionsEnd(new_builder)
607
608     # SequenceRNNOptions: not supported
609
610     # StridedSliceOptions
611     import tflite.StridedSliceOptions
612     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().StridedSliceOptions:
613
614         stride_slice_option = tflite.StridedSliceOptions.StridedSliceOptions()
615         stride_slice_option.Init(selected_builtin_option.Bytes,
616                                  selected_builtin_option.Pos)
617
618         tflite.StridedSliceOptions.StridedSliceOptionsStart(new_builder)
619         tflite.StridedSliceOptions.StridedSliceOptionsAddBeginMask(
620             new_builder, stride_slice_option.BeginMask())
621         tflite.StridedSliceOptions.StridedSliceOptionsAddEndMask(
622             new_builder, stride_slice_option.EndMask())
623         tflite.StridedSliceOptions.StridedSliceOptionsAddEllipsisMask(
624             new_builder, stride_slice_option.EllipsisMask())
625         tflite.StridedSliceOptions.StridedSliceOptionsAddNewAxisMask(
626             new_builder, stride_slice_option.NewAxisMask())
627         tflite.StridedSliceOptions.StridedSliceOptionsAddShrinkAxisMask(
628             new_builder, stride_slice_option.ShrinkAxisMask())
629
630         return tflite.StridedSliceOptions.StridedSliceOptionsEnd(new_builder)
631
632     # ExpOptions
633     import tflite.ExpOptions
634     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpOptions:
635
636         exp_option = tflite.ExpOptions.ExpOptions()
637         exp_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
638
639         tflite.ExpOptions.ExpOptionsStart(new_builder)
640         return tflite.ExpOptions.ExpOptionsEnd(new_builder)
641
642     # TopKV2Options
643     import tflite.TopKV2Options
644     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TopKV2Options:
645
646         topkv2_option = tflite.TopKV2Options.TopKV2Options()
647         topkv2_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
648
649         tflite.TopKV2Options.TopKV2OptionsStart(new_builder)
650         return tflite.TopKV2Options.TopKV2OptionsEnd(new_builder)
651
652     # SplitOptions
653     import tflite.SplitOptions
654     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SplitOptions:
655
656         split_option = tflite.SplitOptions.SplitOptions()
657         split_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
658
659         tflite.SplitOptions.SplitOptionsStart(new_builder)
660         tflite.SplitOptions.SplitOptionsAddNumSplits(new_builder,
661                                                      split_option.NumSplits())
662         return tflite.SplitOptions.SplitOptionsEnd(new_builder)
663
664     # LogSoftmaxOptions: not supported
665
666     # CastOptions: not supported
667     import tflite.CastOptions
668     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions:
669
670         cast_option = tflite.CastOptions.CastOptions()
671         cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
672
673         tflite.CastOptions.CastOptionsStart(new_builder)
674         return tflite.CastOptions.CastOptionsEnd(new_builder)
675
676     # DequantizeOptions:
677     import tflite.DequantizeOptions
678     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DequantizeOptions:
679
680         dequantize_option = tflite.DequantizeOptions.DequantizeOptions()
681         dequantize_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
682
683         tflite.EqualOptions.DequantizeOptionsStart(new_builder)
684         return tflite.DequantizeOptions.DequantizeOptionsEnd(new_builder)
685
686     # MaximumMinimumOptions: not supported
687
688     # ArgMaxOptions
689     import tflite.ArgMaxOptions
690     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ArgMaxOptions:
691
692         arg_max_option = tflite.ArgMaxOptions.ArgMaxOptions()
693         arg_max_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
694
695         tflite.ArgMaxOptions.ArgMaxOptionsStart(new_builder)
696         tflite.ArgMaxOptions.ArgMaxOptionsAddOutputType(new_builder,
697                                                         arg_max_option.OutputType())
698         return tflite.ArgMaxOptions.ArgMaxOptionsEnd(new_builder)
699
700     # LessOptions: not supported
701
702     # NegOptions
703     import tflite.NegOptions
704     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NegOptions:
705
706         neg_option = tflite.NegOptions.NegOptions()
707         neg_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
708
709         tflite.NegOptions.NegOptionsStart(new_builder)
710         return tflite.NegOptions.NegOptionsEnd(new_builder)
711
712     # EqualOptions
713     import tflite.EqualOptions
714     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().EqualOptions:
715
716         equal_option = tflite.EqualOptions.EqualOptions()
717         equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
718
719         tflite.EqualOptions.EqualOptionsStart(new_builder)
720         return tflite.EqualOptions.EqualOptionsEnd(new_builder)
721
722     # PadV2Options: not supported
723     # GreaterOptions: not supported
724     # GreaterEqualOptions: not supported
725     # LessEqualOptions: not supported
726
727     # SelectOptions
728     import tflite.SelectOptions
729     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SelectOptions:
730
731         select_option = tflite.SelectOptions.SelectOptions()
732         select_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
733
734         tflite.SelectOptions.SelectOptionsStart(new_builder)
735         return tflite.SelectOptions.SelectOptionsEnd(new_builder)
736
737     # SliceOptions: not supported
738
739     # TransposeConvOptions
740     import tflite.TransposeConvOptions
741     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeConvOptions:
742
743         transposeconv_option = tflite.TransposeConvOptions.TransposeConvOptions()
744         transposeconv_option.Init(selected_builtin_option.Bytes,
745                                   selected_builtin_option.Pos)
746
747         tflite.TransposeConvOptions.TransposeConvOptionsStart(new_builder)
748         tflite.TransposeConvOptions.TransposeConvOptionsAddPadding(
749             new_builder, transposeconv_option.Padding())
750         tflite.TransposeConvOptions.TransposeConvOptionsAddStrideW(
751             new_builder, transposeconv_option.StrideW())
752         tflite.TransposeConvOptions.TransposeConvOptionsAddStrideH(
753             new_builder, transposeconv_option.StrideH())
754         return tflite.TransposeConvOptions.TransposeConvOptionsEnd(new_builder)
755
756     # SparseToDenseOptions: not supported
757     # TileOptions: not supported
758
759     # ExpandDimsOptions:
760     import tflite.ExpandDimsOptions
761     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpandDimsOptions:
762
763         expanddims_option = tflite.ExpandDimsOptions.ExpandDimsOptions()
764         expanddims_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
765
766         tflite.ExpandDimsOptions.ExpandDimsOptionsStart(new_builder)
767         return tflite.ExpandDimsOptions.ExpandDimsOptionsEnd(new_builder)
768
769     # NotEqualOptions:
770     import tflite.NotEqualOptions
771     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NotEqualOptions:
772
773         notequal_option = tflite.NotEqualOptions.NotEqualOptions()
774         notequal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
775
776         tflite.NotEqualOptions.NotEqualOptionsStart(new_builder)
777         return tflite.NotEqualOptions.NotEqualOptionsEnd(new_builder)
778
779     # ShapeOptions:
780     import tflite.ShapeOptions
781     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ShapeOptions:
782
783         shape_option = tflite.ShapeOptions.ShapeOptions()
784         shape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
785
786         tflite.ShapeOptions.ShapeOptionsStart(new_builder)
787         tflite.ShapeOptions.ShapeOptionsAddOutType(new_builder, shape_option.OutType())
788         return tflite.ShapeOptions.ShapeOptionsEnd(new_builder)
789
790     # PowOptions: not supported
791     # ArgMinOptions: not supported
792     # FakeQuantOptions: not supported
793
794     # PackOptions:
795     import tflite.PackOptions
796     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PackOptions:
797
798         pack_option = tflite.PackOptions.PackOptions()
799         pack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
800
801         tflite.PackOptions.PackOptionsStart(new_builder)
802         tflite.PackOptions.PackOptionsAddValuesCount(new_builder,
803                                                      pack_option.ValuesCount())
804         tflite.PackOptions.PackOptionsAddAxis(new_builder, pack_option.Axis())
805         return tflite.PackOptions.PackOptionsEnd(new_builder)
806
807     # LogicalOrOptions:
808     import tflite.LogicalOrOptions
809     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalOrOptions:
810
811         logical_or_option = tflite.LogicalAndOptions.LogicalOrOptions()
812         logical_or_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
813
814         tflite.LogicalOrOptions.LogicalOrOptionsStart(new_builder)
815         return tflite.LogicalOrOptions.LogicalOrOptionsEnd(new_builder)
816
817     # OneHotOptions: not supported
818     import tflite.OneHotOptions
819     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().OneHotOptions:
820
821         one_hot_option = tflite.OneHotOptions.OneHotOptions()
822         one_hot_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
823
824         tflite.OneHotOptions.OneHotOptionsStart(new_builder)
825         tflite.OneHotOptions.OneHotOptionsAddAxis(new_builder, one_hot_option.Axis())
826         return tflite.OneHotOptions.OneHotOptionsEnd(new_builder)
827
828     # LogicalNotOptions
829     import tflite.LogicalNotOptions
830     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalNotOptions:
831
832         equal_option = tflite.LogicalNotOptions.LogicalNotOptions()
833         equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
834
835         tflite.LogicalNotOptions.LogicalNotOptionsStart(new_builder)
836         return tflite.LogicalNotOptions.LogicalNotOptionsEnd(new_builder)
837
838     # UnpackOptions:
839     import tflite.UnpackOptions
840     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().UnpackOptions:
841
842         unpack_option = tflite.UnpackOptions.UnpackOptions()
843         unpack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
844
845         tflite.UnpackOptions.UnpackOptionsStart(new_builder)
846         tflite.UnpackOptions.UnpackOptionsAddNum(new_builder, unpack_option.Num())
847         tflite.UnpackOptions.UnpackOptionsAddAxis(new_builder, unpack_option.Axis())
848         return tflite.UnpackOptions.UnpackOptionsEnd(new_builder)
849
850     # FloorDivOptions: not supported
851     # SquareOptions: not supported
852     # ZerosLikeOptions: not supported
853     # FillOptions: not supported
854
855     # LogicalAndOptions
856     import tflite.LogicalAndOptions
857     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalAndOptions:
858
859         logical_and_option = tflite.LogicalAndOptions.LogicalAndOptions()
860         logical_and_option.Init(selected_builtin_option.Bytes,
861                                 selected_builtin_option.Pos)
862
863         tflite.LogicalAndOptions.LogicalAndOptionsStart(new_builder)
864         return tflite.LogicalAndOptions.LogicalAndOptionsEnd(new_builder)
865
866     # LogicalNotOptions: not supported
867     # UnpackOptions: not supported
868     # FloorDivOptions: not supported
869     # SquareOptions: not supported
870     # ZerosLikeOptions: not supported
871     # FillOptions: not supported
872     # BidirectionalSequenceLSTMOptions: not supported
873     # BidirectionalSequenceRNNOptions: not supported
874     # UnidirectionalSequenceLSTMOptions: not supported
875     # FloorModOptions: not supported
876     # RangeOptions: not supported
877     # ResizeNearestNeighborOptions: not supported
878
879     # LeakyReluOptions
880     import tflite.LeakyReluOptions
881     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LeakyReluOptions:
882
883         leaky_relu_option = tflite.LeakyReluOptions.LeakyReluOptions()
884         leaky_relu_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
885
886         tflite.LeakyReluOptions.LeakyReluOptionsStart(new_builder)
887         tflite.LeakyReluOptions.LeakyReluOptionsAddAlpha(new_builder,
888                                                          leaky_relu_option.Alpha())
889         return tflite.LeakyReluOptions.LeakyReluOptionsEnd(new_builder)
890
891     # SquaredDifferenceOptions
892     import tflite.SquaredDifferenceOptions
893     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
894     ).SquaredDifferenceOptions:
895
896         squared_difference_option = tflite.SquaredDifferenceOptions.SquaredDifferenceOptions(
897         )
898         squared_difference_option.Init(selected_builtin_option.Bytes,
899                                        selected_builtin_option.Pos)
900
901         tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsStart(new_builder)
902         return tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsEnd(new_builder)
903
904     # MirrorPadOptions: not supported
905     # AbsOptions: not supported
906     # SplitVOptions: not supported
907
908     # IfOptions
909     import tflite.IfOptions
910     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().IfOptions:
911
912         if_option = tflite.IfOptions.IfOptions()
913         if_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
914
915         tflite.IfOptions.IfOptionsStart(new_builder)
916         tflite.IfOptions.IfOptionsAddElseSubgraphIndex(
917             new_builder, used_subgraphs_dic[if_option.ElseSubgraphIndex()])
918         tflite.IfOptions.IfOptionsAddThenSubgraphIndex(
919             new_builder, used_subgraphs_dic[if_option.ThenSubgraphIndex()])
920         return tflite.IfOptions.IfOptionsEnd(new_builder)
921
922     # WhileOptions
923     import tflite.WhileOptions
924     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().WhileOptions:
925
926         while_option = tflite.WhileOptions.WhileOptions()
927         while_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
928
929         tflite.WhileOptions.WhileOptionsStart(new_builder)
930         tflite.WhileOptions.WhileOptionsAddBodySubgraphIndex(
931             new_builder, used_subgraphs_dic[while_option.BodySubgraphIndex()])
932         tflite.WhileOptions.WhileOptionsAddCondSubgraphIndex(
933             new_builder, used_subgraphs_dic[while_option.CondSubgraphIndex()])
934         return tflite.WhileOptions.WhileOptionsEnd(new_builder)
935
936     # Cannot handle builtin option type yet
937     print("Cannot handle BuiltinOptions {} yet. See BuiltinOptions.py for op name".format(
938         builtin_option_type))
939     exit(1)
940
941
def GenerateOperator(new_builder, selected_operator, used_tensors_dic, used_opcodes_dic,
                     used_subgraphs_dic):
    """Serialize one operator of the source model into ``new_builder``.

    Opcode and tensor indices are translated from the source model's numbering
    into the new model's numbering through the given dictionaries.

    Args:
        new_builder: flatbuffers.Builder that receives the new model.
        selected_operator: tflite.Operator read from the source model.
        used_tensors_dic: old tensor index -> new tensor index.
        used_opcodes_dic: old opcode index -> new opcode index.
        used_subgraphs_dic: old subgraph index -> new subgraph index; forwarded
            unchanged to GenerateBuiltinOption.

    Returns:
        Offset of the finished Operator table.
    """
    op = selected_operator
    remapped_opcode = used_opcodes_dic[op.OpcodeIndex()]

    # Input tensor indices. -1 marks an omitted optional input and is written
    # through untouched. Flatbuffers vectors are filled back-to-front.
    inputs_len = op.InputsLength()
    inputs_vec = None
    if inputs_len:
        tflite.Operator.OperatorStartInputsVector(new_builder, inputs_len)
        for pos in range(inputs_len - 1, -1, -1):
            old_idx = op.Inputs(pos)
            new_builder.PrependInt32(-1 if old_idx == -1 else used_tensors_dic[old_idx])
        inputs_vec = new_builder.EndVector(inputs_len)

    # Output tensor indices (always remapped).
    outputs_len = op.OutputsLength()
    outputs_vec = None
    if outputs_len:
        tflite.Operator.OperatorStartOutputsVector(new_builder, outputs_len)
        for pos in range(outputs_len - 1, -1, -1):
            new_builder.PrependInt32(used_tensors_dic[op.Outputs(pos)])
        outputs_vec = new_builder.EndVector(outputs_len)

    # Builtin options; type 0 means the operator carries none.
    options_type = op.BuiltinOptionsType()
    options_table = None
    if options_type:
        options_table = GenerateBuiltinOption(new_builder, op.BuiltinOptions(),
                                              options_type, used_subgraphs_dic)

    # Custom options are an opaque byte vector copied verbatim.
    custom_len = op.CustomOptionsLength()
    custom_vec = None
    if custom_len:
        tflite.Operator.OperatorStartCustomOptionsVector(new_builder, custom_len)
        for pos in range(custom_len - 1, -1, -1):
            new_builder.PrependUint8(op.CustomOptions(pos))
        custom_vec = new_builder.EndVector(custom_len)

    custom_format = op.CustomOptionsFormat()

    # All child objects exist now, so the Operator table itself can be built.
    tflite.Operator.OperatorStart(new_builder)
    tflite.Operator.OperatorAddOpcodeIndex(new_builder, remapped_opcode)
    if inputs_len:
        tflite.Operator.OperatorAddInputs(new_builder, inputs_vec)
    if outputs_len:
        tflite.Operator.OperatorAddOutputs(new_builder, outputs_vec)
    tflite.Operator.OperatorAddBuiltinOptionsType(new_builder, options_type)
    if options_type:
        tflite.Operator.OperatorAddBuiltinOptions(new_builder, options_table)
    if custom_len:
        tflite.Operator.OperatorAddCustomOptions(new_builder, custom_vec)
    tflite.Operator.OperatorAddCustomOptionsFormat(new_builder, custom_format)
    return tflite.Operator.OperatorEnd(new_builder)
1004
1005
def GenerateOperators(new_builder, selected_subgraph, operator_list, used_tensors_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Serialize the selected operators of a subgraph as an Operator vector.

    Operators are emitted in their original subgraph order, not in the order
    in which they appear in ``operator_list``.

    Args:
        new_builder: flatbuffers.Builder that receives the new model.
        selected_subgraph: source tflite.SubGraph.
        operator_list: container of operator indices to keep (membership-tested).
        used_tensors_dic, used_opcodes_dic, used_subgraphs_dic: old -> new index
            maps forwarded to GenerateOperator.

    Returns:
        Offset of the Operator vector, or 0 when the subgraph has no operators
        or none of them is selected.
    """
    total = selected_subgraph.OperatorsLength()
    if total == 0:
        return 0

    built = [
        GenerateOperator(new_builder, selected_subgraph.Operators(idx), used_tensors_dic,
                         used_opcodes_dic, used_subgraphs_dic)
        for idx in range(total) if idx in operator_list
    ]
    if not built:
        return 0

    # Vector elements are prepended, so walk the offsets back-to-front.
    tflite.SubGraph.SubGraphStartOperatorsVector(new_builder, len(built))
    for offset in built[::-1]:
        new_builder.PrependUOffsetTRelative(offset)
    return new_builder.EndVector(len(built))
1031
1032
def GenerateSubgraph(new_builder, selected_subgraph, operator_list, new_input_tensor,
                     new_output_tensor, used_tensors_dic, used_buffers_dic,
                     used_opcodes_dic, used_subgraphs_dic):
    """Serialize one SubGraph table into ``new_builder``.

    Args:
        new_builder: flatbuffers.Builder that receives the new model.
        selected_subgraph: source tflite.SubGraph.
        operator_list: operator indices of ``selected_subgraph`` to keep.
        new_input_tensor: source tensor indices that become the subgraph inputs.
        new_output_tensor: source tensor indices that become the subgraph outputs.
        used_tensors_dic: old tensor index -> new tensor index.
        used_buffers_dic: buffer index map forwarded to GenerateTensors.
        used_opcodes_dic, used_subgraphs_dic: forwarded to GenerateOperators.

    Returns:
        Offset of the finished SubGraph table.
    """

    def _index_vector(start_vector, old_indices):
        # Write an int32 vector of remapped tensor indices (back-to-front, as
        # flatbuffers requires). Returns None when there is nothing to write.
        count = len(old_indices)
        if count == 0:
            return None
        start_vector(new_builder, count)
        for old_idx in reversed(old_indices):
            new_builder.PrependInt32(used_tensors_dic[old_idx])
        return new_builder.EndVector(count)

    tensors = GenerateTensors(new_builder, selected_subgraph, used_tensors_dic,
                              used_buffers_dic)
    inputs_vec = _index_vector(tflite.SubGraph.SubGraphStartInputsVector, new_input_tensor)
    outputs_vec = _index_vector(tflite.SubGraph.SubGraphStartOutputsVector,
                                new_output_tensor)
    operators = GenerateOperators(new_builder, selected_subgraph, operator_list,
                                  used_tensors_dic, used_opcodes_dic, used_subgraphs_dic)

    # The name is optional: only copy a non-empty one.
    name = selected_subgraph.Name()
    name_str = new_builder.CreateString(name) if name else None

    tflite.SubGraph.SubGraphStart(new_builder)
    tflite.SubGraph.SubGraphAddTensors(new_builder, tensors)
    if inputs_vec is not None:
        tflite.SubGraph.SubGraphAddInputs(new_builder, inputs_vec)
    if outputs_vec is not None:
        tflite.SubGraph.SubGraphAddOutputs(new_builder, outputs_vec)
    tflite.SubGraph.SubGraphAddOperators(new_builder, operators)
    if name_str is not None:
        tflite.SubGraph.SubGraphAddName(new_builder, name_str)
    return tflite.SubGraph.SubGraphEnd(new_builder)
1081
1082
def GenerateSubgraphs(args, new_builder, sample_model, operator_list, new_input_tensor,
                      new_output_tensor, used_tensors_dic, used_buffers_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Serialize every used subgraph and return the Model's subgraphs vector.

    The subgraph selected by ``args.subgraph`` becomes subgraph 0 of the new
    model and keeps only the operators in ``operator_list``. Every other
    subgraph listed in ``used_subgraphs_dic`` is copied whole.

    Args:
        args: parsed command-line arguments; ``args.subgraph`` is read.
        new_builder: flatbuffers.Builder that receives the new model.
        sample_model: source tflite.Model.
        operator_list: operator indices of the selected subgraph to keep.
        new_input_tensor / new_output_tensor: source tensor indices of the
            selected subgraph that become its inputs / outputs.
        used_tensors_dic: old -> new tensor index map for the selected subgraph.
        used_buffers_dic: model-wide old -> new buffer index map.
        used_opcodes_dic: old -> new opcode index map.
        used_subgraphs_dic: old -> new subgraph index map.

    Returns:
        Offset of the subgraph vector.
    """
    new_subgraph_list = []

    # The selected subgraph will be primary subgraph of the model to be created newly
    selected_subgraph = sample_model.Subgraphs(args.subgraph)

    # k: old subgraph index, v: new subgraph index.
    # Visit entries in ascending new index so that a subgraph's position in the
    # output vector equals v. Plain dict iteration order is only guaranteed to
    # be insertion order on Python 3.7+.
    for k, v in sorted(used_subgraphs_dic.items(), key=lambda item: item[1]):
        print("Append subgraphs, old index : ", k, ", new index : ", v)
        if k == args.subgraph:
            assert v == 0
            new_subgraph = GenerateSubgraph(new_builder, selected_subgraph, operator_list,
                                            new_input_tensor, new_output_tensor,
                                            used_tensors_dic, used_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
        else:
            # Child subgraphs are copied whole: every operator and every tensor
            # keeps its original index inside the subgraph.
            subg = sample_model.Subgraphs(k)
            subg_operator_idx_list = range(subg.OperatorsLength())
            subg_tensors_dic = {idx: idx for idx in range(subg.TensorsLength())}
            # Buffers are compacted model-wide by GenerateBuffers, so child tensor
            # buffer indices must use the same old -> new map as the primary
            # subgraph. main() registers every child-subgraph tensor buffer in
            # used_buffers_dic. An identity map here would point at stale slots.
            new_subgraph = GenerateSubgraph(new_builder, subg, subg_operator_idx_list,
                                            subg.InputsAsNumpy(), subg.OutputsAsNumpy(),
                                            subg_tensors_dic, used_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
        new_subgraph_list.append(new_subgraph)

    new_subgraph_num = len(new_subgraph_list)
    tflite.Model.ModelStartSubgraphsVector(new_builder, new_subgraph_num)
    # Vector elements are prepended, so walk the offsets back-to-front.
    for new_subgraph in reversed(new_subgraph_list):
        new_builder.PrependUOffsetTRelative(new_subgraph)

    return new_builder.EndVector(new_subgraph_num)
1125
1126
def GenerateBuffers(new_builder, sample_model, used_buffers_dic):
    """Copy the used buffers of the source model into the Model buffers vector.

    Only buffers whose source index is a key of ``used_buffers_dic`` are kept,
    in ascending source-index order. All data vectors are serialized first and
    the Buffer tables afterwards, because a flatbuffers table cannot be open
    while a vector is being built.

    Args:
        new_builder: flatbuffers.Builder that receives the new model.
        sample_model: source tflite.Model.
        used_buffers_dic: mapping whose keys are the source buffer indices to keep.

    Returns:
        Offset of the buffers vector, or 0 when the model has no buffers or
        none of them is used.
    """
    total = sample_model.BuffersLength()
    if total == 0:
        return 0

    kept = [idx for idx in range(total) if idx in used_buffers_dic]

    # Pass 1: raw byte payloads of the non-empty kept buffers.
    data_offsets = {}
    for idx in kept:
        src = sample_model.Buffers(idx)
        size = src.DataLength()
        if size == 0:
            continue
        tflite.Buffer.BufferStartDataVector(new_builder, size)
        for pos in range(size - 1, -1, -1):
            new_builder.PrependUint8(src.Data(pos))
        data_offsets[idx] = new_builder.EndVector(size)

    # Pass 2: one Buffer table per kept buffer; empty buffers get no data field.
    tables = []
    for idx in kept:
        tflite.Buffer.BufferStart(new_builder)
        if idx in data_offsets:
            tflite.Buffer.BufferAddData(new_builder, data_offsets[idx])
        tables.append(tflite.Buffer.BufferEnd(new_builder))

    if not tables:
        return 0

    tflite.Model.ModelStartBuffersVector(new_builder, len(tables))
    for table in reversed(tables):
        new_builder.PrependUOffsetTRelative(table)
    return new_builder.EndVector(len(tables))
1169
1170
def GenerateModel(args, new_builder, sample_model, operator_list, new_input_tensors,
                  new_output_tensors, used_tensors_dic, used_buffers_dic,
                  used_opcodes_dic, used_subgraphs_dic):
    """Build the root Model table of the new model.

    Child objects (operator codes, subgraphs, description string, buffers) are
    serialized before ModelStart, because flatbuffers forbids creating other
    objects while a table is under construction.

    Args:
        args: parsed command-line arguments, forwarded to GenerateSubgraphs.
        new_builder: flatbuffers.Builder that receives the new model.
        sample_model: source tflite.Model.
        operator_list: operator indices of the selected subgraph to keep.
        new_input_tensors / new_output_tensors: source tensor indices that
            become the primary subgraph's inputs / outputs.
        used_tensors_dic, used_buffers_dic, used_opcodes_dic,
        used_subgraphs_dic: old -> new index maps.

    Returns:
        Offset of the Model table, to be passed to ``new_builder.Finish``.
    """
    # uint
    version = sample_model.Version()

    # pointer of operator code 'table' vector
    operator_codes = GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
                                           used_subgraphs_dic)

    # subgraphs
    subgraphs = GenerateSubgraphs(args, new_builder, sample_model, operator_list,
                                  new_input_tensors, new_output_tensors, used_tensors_dic,
                                  used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)

    # description: optional in the schema. Description() returns None when it
    # is absent, and CreateString(None) raises TypeError, so copy it only if set.
    description = sample_model.Description()
    description_string = None
    if description is not None:
        description_string = new_builder.CreateString(description)

    # buffers
    buffers = GenerateBuffers(new_builder, sample_model, used_buffers_dic)

    # Generate model
    tflite.Model.ModelStart(new_builder)
    tflite.Model.ModelAddVersion(new_builder, version)
    tflite.Model.ModelAddOperatorCodes(new_builder, operator_codes)
    tflite.Model.ModelAddSubgraphs(new_builder, subgraphs)
    if description_string is not None:
        tflite.Model.ModelAddDescription(new_builder, description_string)
    tflite.Model.ModelAddBuffers(new_builder, buffers)

    return tflite.Model.ModelEnd(new_builder)
1201
1202
def main(args):
    """Extract the selected operators of one subgraph into a new tflite model.

    Reads operator indices from ``args.opcode_list`` and collects every tensor,
    opcode, buffer and child subgraph those operators need. It renumbers them
    densely and writes the resulting model to ``args.output_model``.

    Args:
        args: parsed command-line arguments (input_model, opcode_list,
            output_model, subgraph).
    """
    input_model_file = args.input_model
    oplist_file = args.opcode_list
    output_model_file = args.output_model
    subgraph = args.subgraph

    # Parse operator list file
    operator_list = GetOperatorList(oplist_file)

    # Load the source model and the subgraph to extract operators from
    sample_buf = bytearray(input_model_file.read())
    sample_model = tflite.Model.Model.GetRootAsModel(sample_buf, 0)
    sample_subgraph = sample_model.Subgraphs(subgraph)

    # Subgraphs to keep: the selected one first (new index 0), followed by
    # whatever GetUsedSubgraphsList appends for the selected operators.
    used_subgraphs_list = [subgraph]
    GetUsedSubgraphsList(sample_model, subgraph, operator_list, used_subgraphs_list)
    used_subgraphs_dic = {
        old_idx: new_idx
        for new_idx, old_idx in enumerate(used_subgraphs_list)
    }

    # Collect used tensors & opcodes of the selected operators
    used_tensors = set()
    used_opcodes = set()
    for operator_idx in operator_list:
        operator = sample_subgraph.Operators(operator_idx)
        for input_idx in range(operator.InputsLength()):
            input_tensor_idx = operator.Inputs(input_idx)
            if input_tensor_idx != -1:  # -1 marks an omitted optional input
                used_tensors.add(input_tensor_idx)
        for output_idx in range(operator.OutputsLength()):
            used_tensors.add(operator.Outputs(output_idx))
        used_opcodes.add(operator.OpcodeIndex())

    # Child subgraphs are copied whole, so all of their opcodes are needed too
    for subgraph_idx in used_subgraphs_list:
        if subgraph_idx == subgraph:
            continue
        child = sample_model.Subgraphs(subgraph_idx)
        for operator_idx in range(child.OperatorsLength()):
            used_opcodes.add(child.Operators(operator_idx).OpcodeIndex())

    used_tensors = sorted(used_tensors)
    used_opcodes = sorted(used_opcodes)

    # Collect used buffers. buffer[0] should be blank, so it is always kept.
    # A set is required: several tensors may share one buffer (including
    # buffer 0). Duplicates would make the enumerate-based numbering below skip
    # indices and disagree with the compacted vector written by GenerateBuffers.
    used_buffers = {0}
    for used_tensor in used_tensors:
        used_buffers.add(sample_subgraph.Tensors(used_tensor).Buffer())

    # Append buffers of tensors of child subgraphs
    for subgraph_idx in used_subgraphs_list:
        if subgraph_idx == subgraph:
            continue
        child = sample_model.Subgraphs(subgraph_idx)
        for tensor_idx in range(child.TensorsLength()):
            used_buffers.add(child.Tensors(tensor_idx).Buffer())

    used_buffers = sorted(used_buffers)

    # Assign new dense indices (old index -> new index, ascending old order)
    used_opcodes_dic = {old: new for new, old in enumerate(used_opcodes)}
    used_tensors_dic = {old: new for new, old in enumerate(used_tensors)}
    used_buffers_dic = {old: new for new, old in enumerate(used_buffers)}

    # Find input & output tensors of the new model. Start from every used
    # tensor. A tensor produced by a selected operator is not a model input,
    # and one consumed by a selected operator is not a model output. A tensor
    # whose buffer holds data is dropped from the inputs when seen as an
    # operator input, and from the outputs when seen as an operator output.
    new_input_tensors = used_tensors[:]
    new_output_tensors = used_tensors[:]

    def has_constant_data(tensor_idx):
        # True when the tensor's buffer in the source model holds data.
        buffer_idx = sample_subgraph.Tensors(tensor_idx).Buffer()
        return sample_model.Buffers(buffer_idx).DataLength() != 0

    for operator_idx in operator_list:
        operator = sample_subgraph.Operators(operator_idx)
        for input_idx in range(operator.InputsLength()):
            input_tensor_idx = operator.Inputs(input_idx)
            if input_tensor_idx == -1:
                continue
            if input_tensor_idx in new_output_tensors:
                new_output_tensors.remove(input_tensor_idx)
            if (input_tensor_idx in new_input_tensors
                    and has_constant_data(input_tensor_idx)):
                new_input_tensors.remove(input_tensor_idx)

        for output_idx in range(operator.OutputsLength()):
            output_tensor_idx = operator.Outputs(output_idx)
            if output_tensor_idx in new_input_tensors:
                new_input_tensors.remove(output_tensor_idx)
            if (output_tensor_idx in new_output_tensors
                    and has_constant_data(output_tensor_idx)):
                new_output_tensors.remove(output_tensor_idx)

    new_input_tensors_newidx = [used_tensors_dic[idx] for idx in new_input_tensors]
    new_output_tensors_newidx = [used_tensors_dic[idx] for idx in new_output_tensors]

    print("Input tensor(s): " + str(new_input_tensors_newidx))
    print("Output tensor(s): " + str(new_output_tensors_newidx))

    # Create new model file
    new_builder = flatbuffers.Builder(1024)

    new_model = GenerateModel(args, new_builder, sample_model, operator_list,
                              new_input_tensors, new_output_tensors, used_tensors_dic,
                              used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)

    new_builder.Finish(new_model, file_identifier=b'TFL3')
    output_model_file.write(new_builder.Output())
1353
1354
if __name__ == '__main__':
    # Command-line interface:
    #   select_operator.py INPUT_MODEL OPCODE_LIST OUTPUT_MODEL [-g SUBGRAPH]
    #
    # TODO
    #   Select multiple subgraph
    #   Select subgraph by using opcode list file
    #   Select opcode list by using argument
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("input_model", type=argparse.FileType('rb'),
                            help="input tflite model file to read")
    arg_parser.add_argument("opcode_list", type=argparse.FileType('r'),
                            help="text file including selected operator list")
    arg_parser.add_argument("output_model", type=argparse.FileType('wb'),
                            help="output tflite model file")
    arg_parser.add_argument('-g', '--subgraph', type=int, default=0,
                            help="subgraph to use (default: 0)")

    args = arg_parser.parse_args()
    main(args)
1378     # Call main function
1379     main(args)