333ca32f6fff3eb451f2500b4d26dbfe7e5634ed
[platform/core/ml/nnfw.git] / tools / tflitefile_tool / select_operator.py
1 #!/usr/bin/env python
2
3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import os
18 import sys
19 import numpy
20
21 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tflite'))
22 sys.path.append(
23     os.path.join(
24         os.path.dirname(os.path.abspath(__file__)), '../../externals/flatbuffers/python'))
25
26 import flatbuffers
27 import tflite.Model
28 import tflite.SubGraph
29 import tflite.BuiltinOptions
30 import argparse
31
32
33 # Assume we use only main model in model file
34 # Get selected operators from file, and return operator index list
def GetOperatorList(oplist_file):
    """Parse the operator selection file into a list of operator indices.

    The file holds whitespace-separated tokens spread over any number of
    lines. Each token is either a single non-negative index (``7``) or an
    inclusive range (``3-5``). Tokens are expanded in the order they appear
    and duplicates are kept.

    Args:
        oplist_file: Open text file (any iterable yielding text lines).

    Returns:
        List of the selected operator indices as ints.

    Raises:
        SystemExit: With status 1 when a token is neither an index nor a
            range, or when the file selects no operator at all. An error
            message is printed first.
    """
    opcode_list = []

    for line in oplist_file:
        for word in line.split():
            if word.isdigit():
                opcode_list.append(int(word))
                continue

            # Anything that is not a plain number has to be "start-end".
            bounds = word.split('-')
            if len(bounds) == 2 and bounds[0].isdigit() and bounds[1].isdigit():
                # Inclusive on both ends; a reversed range (start > end)
                # expands to nothing.
                opcode_list.extend(range(int(bounds[0]), int(bounds[1]) + 1))
                continue

            print("Error: Cannot get operator list")
            print(
                "Please pass operators as operator index or range list split by space and/or line"
            )
            # sys.exit() rather than the exit() helper: the latter is injected
            # by the site module and is missing when site is not loaded.
            sys.exit(1)

    if not opcode_list:
        print("No selected operator")
        sys.exit(1)

    return opcode_list
64
65
def GetUsedSubgraphsList(sample_model, subg_num, operator_list, used_subgraphs_list):
    """Collect, recursively, the subgraphs referenced by If/While operators.

    The operators ``operator_list`` of subgraph ``subg_num`` are inspected.
    For every operator whose builtin options are IfOptions or WhileOptions,
    the referenced subgraph indices (else/then, body/cond) are recorded.
    Each referenced subgraph that is not yet in ``used_subgraphs_list`` is
    appended to it and then scanned in turn, using all of its operators.

    Args:
        sample_model: Source model; must provide ``Subgraphs(index)``.
        subg_num: Index of the subgraph whose operators are inspected.
        operator_list: Iterable of operator indices within that subgraph.
        used_subgraphs_list: List of subgraph indices found so far. It is
            extended in place and also stops the recursion from revisiting
            a subgraph.

    Returns:
        None. The result is accumulated in ``used_subgraphs_list``.
    """
    import tflite.IfOptions
    import tflite.WhileOptions

    option_types = tflite.BuiltinOptions.BuiltinOptions()
    selected_subgraph = sample_model.Subgraphs(subg_num)
    subg_list = []

    for operator_idx in operator_list:
        selected_operator = selected_subgraph.Operators(operator_idx)
        option_type = selected_operator.BuiltinOptionsType()

        if option_type == option_types.IfOptions:
            selected_builtin_option = selected_operator.BuiltinOptions()
            if_option = tflite.IfOptions.IfOptions()
            if_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)

            subg_list.append(if_option.ElseSubgraphIndex())
            subg_list.append(if_option.ThenSubgraphIndex())
        elif option_type == option_types.WhileOptions:
            selected_builtin_option = selected_operator.BuiltinOptions()
            while_option = tflite.WhileOptions.WhileOptions()
            while_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)

            subg_list.append(while_option.BodySubgraphIndex())
            subg_list.append(while_option.CondSubgraphIndex())

    for idx in subg_list:
        if idx in used_subgraphs_list:
            continue
        # Record the subgraph before descending so that subgraphs referring
        # to each other cannot cause endless recursion.
        used_subgraphs_list.append(idx)
        # A referenced subgraph is scanned completely: range(n) yields the
        # operator indices 0..n-1, so the last operator is inspected too.
        nested_operator_num = sample_model.Subgraphs(idx).OperatorsLength()
        GetUsedSubgraphsList(sample_model, idx, range(nested_operator_num),
                             used_subgraphs_list)
100
101
def GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
                          used_subgraphs_dic):
    """Copy the used operator codes of ``sample_model`` into ``new_builder``.

    Only operator code indices contained in ``used_opcodes_dic`` are copied,
    in ascending index order. ``used_subgraphs_dic`` is accepted but not used
    by this function.

    Returns:
        The value of ``new_builder.EndVector`` for the new operator code
        vector, or 0 when the source model has no operator codes.
    """
    total_codes = sample_model.OperatorCodesLength()
    if total_codes == 0:
        return 0

    used_indices = [idx for idx in range(total_codes) if idx in used_opcodes_dic]

    # First pass: create every distinct, non-empty custom code string once,
    # before any OperatorCode table is started.
    custom_code_offsets = {}
    for idx in used_indices:
        custom_code = sample_model.OperatorCodes(idx).CustomCode()
        if not custom_code or custom_code in custom_code_offsets:
            continue
        custom_code_offsets[custom_code] = new_builder.CreateString(custom_code)

    # Second pass: build one OperatorCode table per used operator code.
    code_offsets = []
    for idx in used_indices:
        source_code = sample_model.OperatorCodes(idx)

        tflite.OperatorCode.OperatorCodeStart(new_builder)
        tflite.OperatorCode.OperatorCodeAddBuiltinCode(new_builder,
                                                       source_code.BuiltinCode())

        custom_code = source_code.CustomCode()
        if custom_code in custom_code_offsets:
            tflite.OperatorCode.OperatorCodeAddCustomCode(
                new_builder, custom_code_offsets[custom_code])
        code_offsets.append(tflite.OperatorCode.OperatorCodeEnd(new_builder))

    # Offsets are prepended in reverse so the vector keeps the source order.
    code_count = len(code_offsets)
    tflite.Model.ModelStartOperatorCodesVector(new_builder, code_count)
    for code_offset in reversed(code_offsets):
        new_builder.PrependUOffsetTRelative(code_offset)

    return new_builder.EndVector(code_count)
145
146
def GenerateQuantization(new_builder, selected_quantization):
    """Copy a QuantizationParameters table into ``new_builder``.

    The min, max, scale and zero_point vectors are copied in that order.
    A vector whose length is 0 in the source is neither created nor added
    to the new table.

    Returns:
        The value of ``QuantizationParametersEnd`` for the new table.
    """
    quant_api = tflite.QuantizationParameters

    def copy_vector(length, start_vector, prepend_item, read_item):
        """Build a copy of one source vector; return None if it is empty."""
        if length == 0:
            return None
        start_vector(new_builder, length)
        # Items are prepended back to front.
        for item_idx in reversed(range(length)):
            prepend_item(read_item(item_idx))
        return new_builder.EndVector(length)

    # All vectors are finished before the table itself is started.
    min_len = selected_quantization.MinLength()
    min_offset = copy_vector(min_len, quant_api.QuantizationParametersStartMinVector,
                             new_builder.PrependFloat32, selected_quantization.Min)

    max_len = selected_quantization.MaxLength()
    max_offset = copy_vector(max_len, quant_api.QuantizationParametersStartMaxVector,
                             new_builder.PrependFloat32, selected_quantization.Max)

    scale_len = selected_quantization.ScaleLength()
    scale_offset = copy_vector(scale_len,
                               quant_api.QuantizationParametersStartScaleVector,
                               new_builder.PrependFloat32,
                               selected_quantization.Scale)

    zero_point_len = selected_quantization.ZeroPointLength()
    zero_point_offset = copy_vector(
        zero_point_len, quant_api.QuantizationParametersStartZeroPointVector,
        new_builder.PrependInt64, selected_quantization.ZeroPoint)

    # Assemble the table from the vectors that exist.
    quant_api.QuantizationParametersStart(new_builder)
    if min_len != 0:
        quant_api.QuantizationParametersAddMin(new_builder, min_offset)
    if max_len != 0:
        quant_api.QuantizationParametersAddMax(new_builder, max_offset)
    if scale_len != 0:
        quant_api.QuantizationParametersAddScale(new_builder, scale_offset)
    if zero_point_len != 0:
        quant_api.QuantizationParametersAddZeroPoint(new_builder, zero_point_offset)

    return quant_api.QuantizationParametersEnd(new_builder)
198
199
def GenerateTensor(new_builder, selected_tensor, used_buffers_dic):
    """Copy one tensor table into ``new_builder``.

    Shape, type, name and quantization are copied from ``selected_tensor``.
    The buffer index is translated through ``used_buffers_dic``.

    Args:
        new_builder: Builder receiving the new tensor.
        selected_tensor: Source tensor to copy.
        used_buffers_dic: Mapping from source buffer index to the buffer
            index in the new model.

    Returns:
        The value of ``TensorEnd`` for the new tensor table.

    Raises:
        KeyError: If the tensor's buffer index is not in ``used_buffers_dic``.
    """
    # Shape vector; it is created even when the tensor has no dimensions.
    shape_num = selected_tensor.ShapeLength()
    tflite.Tensor.TensorStartShapeVector(new_builder, shape_num)
    for shape_idx in reversed(range(shape_num)):
        new_builder.PrependInt32(selected_tensor.Shape(shape_idx))
    new_shape = new_builder.EndVector(shape_num)

    tensor_type = selected_tensor.Type()

    # Translate the buffer index into the new model's numbering.
    new_buffer_idx = used_buffers_dic[selected_tensor.Buffer()]

    # Name string. A missing name must not reach CreateString: the accessor
    # may yield None for an unnamed tensor (NOTE(review): depends on the
    # flatbuffers generator version - confirm), and CreateString only accepts
    # str/bytes. The comparison with "" is kept from the existing behaviour.
    name_string = selected_tensor.Name()
    has_name = name_string is not None and name_string != ""
    if has_name:
        new_name = new_builder.CreateString(name_string)

    # Quantization table, when the source tensor carries one.
    quantization = selected_tensor.Quantization()
    has_quantization = quantization is not None
    if has_quantization:
        new_quantization = GenerateQuantization(new_builder, quantization)

    # All child objects are finished; now assemble the tensor table itself.
    tflite.Tensor.TensorStart(new_builder)
    tflite.Tensor.TensorAddShape(new_builder, new_shape)
    tflite.Tensor.TensorAddType(new_builder, tensor_type)
    tflite.Tensor.TensorAddBuffer(new_builder, new_buffer_idx)
    if has_name:
        tflite.Tensor.TensorAddName(new_builder, new_name)
    if has_quantization:
        tflite.Tensor.TensorAddQuantization(new_builder, new_quantization)

    return tflite.Tensor.TensorEnd(new_builder)
238
239
def GenerateTensors(new_builder, selected_subgraph, used_tensors_dic, used_buffers_dic):
    """Copy the used tensors of ``selected_subgraph`` into ``new_builder``.

    Only tensor indices contained in ``used_tensors_dic`` are copied, in
    ascending index order; each one is built by ``GenerateTensor``.

    Returns:
        The value of ``new_builder.EndVector`` for the new tensor vector, or
        0 when the subgraph has no tensors or none of them is used.
    """
    source_count = selected_subgraph.TensorsLength()
    if source_count == 0:
        return 0

    tensor_offsets = [
        GenerateTensor(new_builder, selected_subgraph.Tensors(source_idx),
                       used_buffers_dic) for source_idx in range(source_count)
        if source_idx in used_tensors_dic
    ]
    if not tensor_offsets:
        return 0

    # Offsets are prepended in reverse so the vector keeps the source order.
    kept_count = len(tensor_offsets)
    tflite.SubGraph.SubGraphStartTensorsVector(new_builder, kept_count)
    for tensor_offset in tensor_offsets[::-1]:
        new_builder.PrependUOffsetTRelative(tensor_offset)

    return new_builder.EndVector(kept_count)
262
263
264 def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type,
265                           used_subgraphs_dic):
266
267     # Conv2D option
268     import tflite.Conv2DOptions
269     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Conv2DOptions:
270
271         conv2d_options = tflite.Conv2DOptions.Conv2DOptions()
272         conv2d_options.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
273
274         tflite.Conv2DOptions.Conv2DOptionsStart(new_builder)
275         tflite.Conv2DOptions.Conv2DOptionsAddPadding(new_builder,
276                                                      conv2d_options.Padding())
277         tflite.Conv2DOptions.Conv2DOptionsAddStrideW(new_builder,
278                                                      conv2d_options.StrideW())
279         tflite.Conv2DOptions.Conv2DOptionsAddStrideH(new_builder,
280                                                      conv2d_options.StrideH())
281         tflite.Conv2DOptions.Conv2DOptionsAddFusedActivationFunction(
282             new_builder, conv2d_options.FusedActivationFunction())
283         return tflite.Conv2DOptions.Conv2DOptionsEnd(new_builder)
284
285     # DepthwiseConv2D option
286     import tflite.DepthwiseConv2DOptions
287     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
288     ).DepthwiseConv2DOptions:
289
290         depthconv2d_option = tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions()
291         depthconv2d_option.Init(selected_builtin_option.Bytes,
292                                 selected_builtin_option.Pos)
293
294         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsStart(new_builder)
295         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddPadding(
296             new_builder, depthconv2d_option.Padding())
297         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideW(
298             new_builder, depthconv2d_option.StrideW())
299         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideH(
300             new_builder, depthconv2d_option.StrideH())
301         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDepthMultiplier(
302             new_builder, depthconv2d_option.DepthMultiplier())
303         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddFusedActivationFunction(
304             new_builder, depthconv2d_option.FusedActivationFunction())
305         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationWFactor(
306             new_builder, depthconv2d_option.DilationWFactor())
307         tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationHFactor(
308             new_builder, depthconv2d_option.DilationHFactor())
309         return tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsEnd(new_builder)
310
311     # ConcatEmbeddingsOptions: not supported
312     # LSHProjectionOptions: not supported
313
    # Pool2DOptions
315     import tflite.Pool2DOptions
316     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Pool2DOptions:
317
318         pool2d_option = tflite.Pool2DOptions.Pool2DOptions()
319         pool2d_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
320
321         tflite.Pool2DOptions.Pool2DOptionsStart(new_builder)
322         tflite.Pool2DOptions.Pool2DOptionsAddPadding(new_builder, pool2d_option.Padding())
323         tflite.Pool2DOptions.Pool2DOptionsAddStrideW(new_builder, pool2d_option.StrideW())
324         tflite.Pool2DOptions.Pool2DOptionsAddStrideH(new_builder, pool2d_option.StrideH())
325         tflite.Pool2DOptions.Pool2DOptionsAddFilterWidth(new_builder,
326                                                          pool2d_option.FilterWidth())
327         tflite.Pool2DOptions.Pool2DOptionsAddFilterHeight(new_builder,
328                                                           pool2d_option.FilterHeight())
329         tflite.Pool2DOptions.Pool2DOptionsAddFusedActivationFunction(
330             new_builder, pool2d_option.FusedActivationFunction())
331         return tflite.Pool2DOptions.Pool2DOptionsEnd(new_builder)
332
333     # SVDFOptions: not supported
334
335     # RNNOptions
336     import tflite.RNNOptions
337     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().RNNOptions:
338
339         rnn_option = tflite.RNNOptions.RNNOptions()
340         rnn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
341
342         tflite.RNNOptions.RNNOptionsStart(new_builder)
343         tflite.RNNOptions.RNNOptionsAddFusedActivationFunction(
344             new_builder, rnn_option.FusedActivationFunction())
345         return tflite.RNNOptions.RNNOptionsEnd(new_builder)
346
347     # FullyConnectedOptions
348     import tflite.FullyConnectedOptions
349     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
350     ).FullyConnectedOptions:
351
352         fc_option = tflite.FullyConnectedOptions.FullyConnectedOptions()
353         fc_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
354
355         tflite.FullyConnectedOptions.FullyConnectedOptionsStart(new_builder)
356         tflite.FullyConnectedOptions.FullyConnectedOptionsAddFusedActivationFunction(
357             new_builder, fc_option.FusedActivationFunction())
358         return tflite.FullyConnectedOptions.FullyConnectedOptionsEnd(new_builder)
359
360     # SoftmaxOptions
361     import tflite.SoftmaxOptions
362     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SoftmaxOptions:
363
364         softmax_option = tflite.SoftmaxOptions.SoftmaxOptions()
365         softmax_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
366
367         tflite.SoftmaxOptions.SoftmaxOptionsStart(new_builder)
368         tflite.SoftmaxOptions.SoftmaxOptionsAddBeta(new_builder, softmax_option.Beta())
369         return tflite.SoftmaxOptions.SoftmaxOptionsEnd(new_builder)
370
371     # ConcatenationOptions
372     import tflite.ConcatenationOptions
373     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ConcatenationOptions:
374
375         concat_option = tflite.ConcatenationOptions.ConcatenationOptions()
376         concat_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
377
378         tflite.ConcatenationOptions.ConcatenationOptionsStart(new_builder)
379         tflite.ConcatenationOptions.ConcatenationOptionsAddAxis(
380             new_builder, concat_option.Axis())
381         tflite.ConcatenationOptions.ConcatenationOptionsAddFusedActivationFunction(
382             new_builder, concat_option.FusedActivationFunction())
383         return tflite.ConcatenationOptions.ConcatenationOptionsEnd(new_builder)
384
385     # AddOptions
386     import tflite.AddOptions
387     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions:
388
389         add_option = tflite.AddOptions.AddOptions()
390         add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
391
392         tflite.AddOptions.AddOptionsStart(new_builder)
393         tflite.AddOptions.AddOptionsAddFusedActivationFunction(
394             new_builder, add_option.FusedActivationFunction())
395         return tflite.AddOptions.AddOptionsEnd(new_builder)
396
397     # L2NormOptions
398     import tflite.L2NormOptions
399     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().L2NormOptions:
400
401         l2norm_option = tflite.L2NormOptions.L2NormOptions()
402         l2norm_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
403
404         tflite.L2NormOptions.L2NormOptionsStart(new_builder)
405         tflite.L2NormOptions.L2NormOptionsAddFusedActivationFunction(
406             new_builder, l2norm_option.FusedActivationFunction())
407         return tflite.L2NormOptions.L2NormOptionsEnd(new_builder)
408
409     # LocalResponseNormalizationOptions
410     import tflite.LocalResponseNormalizationOptions
411     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
412     ).LocalResponseNormalizationOptions:
413
414         lrn_option = tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptions(
415         )
416         lrn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
417
418         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsStart(
419             new_builder)
420         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddRadius(
421             new_builder, lrn_option.Radius())
422         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBias(
423             new_builder, lrn_option.Bias())
424         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddAlpha(
425             new_builder, lrn_option.Alpha())
426         tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBeta(
427             new_builder, lrn_option.Beta())
428         return tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsEnd(
429             new_builder)
430
431     # LSTMOptions: not supported
432
433     # ResizeBilinearOptions
434     import tflite.ResizeBilinearOptions
435     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
436     ).ResizeBilinearOptions:
437
438         resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions()
439         resize_bilinear_option.Init(selected_builtin_option.Bytes,
440                                     selected_builtin_option.Pos)
441
442         tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder)
443         tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners(
444             new_builder, resize_bilinear_option.AlignCorners())
445         return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder)
446
447     # CallOptions: not supported
448
449     # ReshapeOptions
450     import tflite.ReshapeOptions
451     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReshapeOptions:
452
453         reshape_option = tflite.ReshapeOptions.ReshapeOptions()
454         reshape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
455
456         shape_num = reshape_option.NewShapeLength()
457         if shape_num != 0:
458             tflite.ReshapeOptions.ReshapeOptionsStartNewShapeVector(
459                 new_builder, shape_num)
460             for new_shape_idx in reversed(range(shape_num)):
461                 new_shape_val = reshape_option.NewShape(new_shape_idx)
462                 new_builder.PrependInt32(new_shape_val)
463             new_shape = new_builder.EndVector(shape_num)
464
465         tflite.ReshapeOptions.ReshapeOptionsStart(new_builder)
466         if shape_num != 0:
467             tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape)
468         return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder)
469
470     # SkipGramOptions: not supported
471
472     # SpaceToDepthOptions
473     import tflite.SpaceToDepthOptions
474     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SpaceToDepthOptions:
475
476         space_to_depth_option = tflite.SpaceToDepthOptions.SpaceToDepthOptions()
477         space_to_depth_option.Init(selected_builtin_option.Bytes,
478                                    selected_builtin_option.Pos)
479
480         tflite.SpaceToDepthOptions.SpaceToDepthOptionsStart(new_builder)
481         tflite.SpaceToDepthOptions.SpaceToDepthOptionsAddBlockSize(
482             new_builder, space_to_depth_option.BlockSize())
483         return tflite.SpaceToDepthOptions.SpaceToDepthOptionsEnd(new_builder)
484
485     # EmbeddingLookupSparseOptions: not supported
486
487     # MulOptions
488     import tflite.MulOptions
489     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().MulOptions:
490
491         mul_option = tflite.MulOptions.MulOptions()
492         mul_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
493
494         tflite.MulOptions.MulOptionsStart(new_builder)
495         tflite.MulOptions.MulOptionsAddFusedActivationFunction(
496             new_builder, mul_option.FusedActivationFunction())
497         return tflite.MulOptions.MulOptionsEnd(new_builder)
498
499     # PadOptions
500     import tflite.PadOptions
501     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PadOptions:
502
503         pad_option = tflite.PadOptions.PadOptions()
504         pad_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
505
506         tflite.PadOptions.PadOptionsStart(new_builder)
507         return tflite.PadOptions.PadOptionsEnd(new_builder)
508
509     # GatherOptions
510     import tflite.GatherOptions
511     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().GatherOptions:
512
513         gather_option = tflite.GatherOptions.GatherOptions()
514         gather_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
515
516         tflite.GatherOptions.GatherOptionsStart(new_builder)
517         tflite.GatherOptions.GatherOptionsAddAxis(new_builder, gather_option.Axis())
518         return tflite.GatherOptions.GatherOptionsEnd(new_builder)
519
520     # BatchToSpaceNDOptions
521     import tflite.BatchToSpaceNDOptions
522     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
523     ).BatchToSpaceNDOptions:
524
525         btsnd_option = tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptions()
526         btsnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
527
528         tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsStart(new_builder)
529         return tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsEnd(new_builder)
530
531     # SpaceToBatchNDOptions
532     import tflite.SpaceToBatchNDOptions
533     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
534     ).SpaceToBatchNDOptions:
535
536         stbnd_option = tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptions()
537         stbnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
538
539         tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsStart(new_builder)
540         return tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsEnd(new_builder)
541
542     # TransposeOptions:
543     import tflite.TransposeOptions
544     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeOptions:
545
546         transpose_option = tflite.TransposeOptions.TransposeOptions()
547         transpose_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
548
549         tflite.TransposeOptions.TransposeOptionsStart(new_builder)
550         return tflite.TransposeOptions.TransposeOptionsEnd(new_builder)
551
552     # ReducerOptions
553     import tflite.ReducerOptions
554     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReducerOptions:
555
556         reducer_option = tflite.ReducerOptions.ReducerOptions()
557         reducer_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
558
559         tflite.ReducerOptions.ReducerOptionsStart(new_builder)
560         tflite.ReducerOptions.ReducerOptionsAddKeepDims(new_builder,
561                                                         reducer_option.KeepDims())
562         return tflite.ReducerOptions.ReducerOptionsEnd(new_builder)
563
564     # SubOptions
565     import tflite.SubOptions
566     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions:
567
568         sub_option = tflite.SubOptions.SubOptions()
569         sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
570
571         tflite.SubOptions.SubOptionsStart(new_builder)
572         tflite.SubOptions.SubOptionsAddFusedActivationFunction(
573             new_builder, sub_option.FusedActivationFunction())
574         return tflite.SubOptions.SubOptionsEnd(new_builder)
575
576     # DivOptions
577     import tflite.DivOptions
578     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DivOptions:
579
580         div_option = tflite.DivOptions.DivOptions()
581         div_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
582
583         tflite.DivOptions.DivOptionsStart(new_builder)
584         tflite.DivOptions.DivOptionsAddFusedActivationFunction(
585             new_builder, div_option.FusedActivationFunction())
586         return tflite.DivOptions.DivOptionsEnd(new_builder)
587
588     # SqueezeOptions
589     import tflite.SqueezeOptions
590     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SqueezeOptions:
591
592         squeeze_option = tflite.SqueezeOptions.SqueezeOptions()
593         squeeze_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
594
595         squeeze_dims_num = squeeze_option.SqueezeDimsLength()
596         if squeeze_dims_num != 0:
597             tflite.SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector(
598                 new_builder, squeeze_dims_num)
599             for squeeze_dims_idx in reversed(range(squeeze_dims_num)):
600                 squeeze_dims_val = squeeze_option.SqueezeDims(squeeze_dims_idx)
601                 new_builder.PrependInt32(squeeze_dims_val)
602             new_squeeze_dims = new_builder.EndVector(squeeze_dims_num)
603
604         tflite.SqueezeOptions.SqueezeOptionsStart(new_builder)
605         if squeeze_dims_num != 0:
606             tflite.SqueezeOptions.SqueezeOptionsAddSqueezeDims(new_builder,
607                                                                new_squeeze_dims)
608         return tflite.SqueezeOptions.SqueezeOptionsEnd(new_builder)
609
610     # SequenceRNNOptions: not supported
611
612     # StridedSliceOptions
613     import tflite.StridedSliceOptions
614     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().StridedSliceOptions:
615
616         stride_slice_option = tflite.StridedSliceOptions.StridedSliceOptions()
617         stride_slice_option.Init(selected_builtin_option.Bytes,
618                                  selected_builtin_option.Pos)
619
620         tflite.StridedSliceOptions.StridedSliceOptionsStart(new_builder)
621         tflite.StridedSliceOptions.StridedSliceOptionsAddBeginMask(
622             new_builder, stride_slice_option.BeginMask())
623         tflite.StridedSliceOptions.StridedSliceOptionsAddEndMask(
624             new_builder, stride_slice_option.EndMask())
625         tflite.StridedSliceOptions.StridedSliceOptionsAddEllipsisMask(
626             new_builder, stride_slice_option.EllipsisMask())
627         tflite.StridedSliceOptions.StridedSliceOptionsAddNewAxisMask(
628             new_builder, stride_slice_option.NewAxisMask())
629         tflite.StridedSliceOptions.StridedSliceOptionsAddShrinkAxisMask(
630             new_builder, stride_slice_option.ShrinkAxisMask())
631
632         return tflite.StridedSliceOptions.StridedSliceOptionsEnd(new_builder)
633
634     # ExpOptions
635     import tflite.ExpOptions
636     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpOptions:
637
638         exp_option = tflite.ExpOptions.ExpOptions()
639         exp_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
640
641         tflite.ExpOptions.ExpOptionsStart(new_builder)
642         return tflite.ExpOptions.ExpOptionsEnd(new_builder)
643
644     # TopKV2Options
645     import tflite.TopKV2Options
646     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TopKV2Options:
647
648         topkv2_option = tflite.TopKV2Options.TopKV2Options()
649         topkv2_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
650
651         tflite.TopKV2Options.TopKV2OptionsStart(new_builder)
652         return tflite.TopKV2Options.TopKV2OptionsEnd(new_builder)
653
654     # SplitOptions
655     import tflite.SplitOptions
656     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SplitOptions:
657
658         split_option = tflite.SplitOptions.SplitOptions()
659         split_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
660
661         tflite.SplitOptions.SplitOptionsStart(new_builder)
662         tflite.SplitOptions.SplitOptionsAddNumSplits(new_builder,
663                                                      split_option.NumSplits())
664         return tflite.SplitOptions.SplitOptionsEnd(new_builder)
665
666     # LogSoftmaxOptions: not supported
667
    # CastOptions
669     import tflite.CastOptions
670     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions:
671
672         cast_option = tflite.CastOptions.CastOptions()
673         cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
674
675         tflite.CastOptions.CastOptionsStart(new_builder)
676         return tflite.CastOptions.CastOptionsEnd(new_builder)
677
678     # DequantizeOptions:
679     import tflite.DequantizeOptions
680     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DequantizeOptions:
681
682         dequantize_option = tflite.DequantizeOptions.DequantizeOptions()
683         dequantize_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
684
685         tflite.EqualOptions.DequantizeOptionsStart(new_builder)
686         return tflite.DequantizeOptions.DequantizeOptionsEnd(new_builder)
687
688     # MaximumMinimumOptions: not supported
689
690     # ArgMaxOptions
691     import tflite.ArgMaxOptions
692     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ArgMaxOptions:
693
694         arg_max_option = tflite.ArgMaxOptions.ArgMaxOptions()
695         arg_max_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
696
697         tflite.ArgMaxOptions.ArgMaxOptionsStart(new_builder)
698         tflite.ArgMaxOptions.ArgMaxOptionsAddOutputType(new_builder,
699                                                         arg_max_option.OutputType())
700         return tflite.ArgMaxOptions.ArgMaxOptionsEnd(new_builder)
701
702     # LessOptions: not supported
703
704     # NegOptions
705     import tflite.NegOptions
706     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NegOptions:
707
708         neg_option = tflite.NegOptions.NegOptions()
709         neg_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
710
711         tflite.NegOptions.NegOptionsStart(new_builder)
712         return tflite.NegOptions.NegOptionsEnd(new_builder)
713
714     # EqualOptions
715     import tflite.EqualOptions
716     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().EqualOptions:
717
718         equal_option = tflite.EqualOptions.EqualOptions()
719         equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
720
721         tflite.EqualOptions.EqualOptionsStart(new_builder)
722         return tflite.EqualOptions.EqualOptionsEnd(new_builder)
723
724     # PadV2Options: not supported
725     # GreaterOptions: not supported
726     # GreaterEqualOptions: not supported
727     # LessEqualOptions: not supported
728     # SelectOptions: not supported
729     # SliceOptions: not supported
730
731     # TransposeConvOptions
732     import tflite.TransposeConvOptions
733     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeConvOptions:
734
735         transposeconv_option = tflite.TransposeConvOptions.TransposeConvOptions()
736         transposeconv_option.Init(selected_builtin_option.Bytes,
737                                   selected_builtin_option.Pos)
738
739         tflite.TransposeConvOptions.TransposeConvOptionsStart(new_builder)
740         tflite.TransposeConvOptions.TransposeConvOptionsAddPadding(
741             new_builder, transposeconv_option.Padding())
742         tflite.TransposeConvOptions.TransposeConvOptionsAddStrideW(
743             new_builder, transposeconv_option.StrideW())
744         tflite.TransposeConvOptions.TransposeConvOptionsAddStrideH(
745             new_builder, transposeconv_option.StrideH())
746         return tflite.TransposeConvOptions.TransposeConvOptionsEnd(new_builder)
747
748     # SparseToDenseOptions: not supported
749     # TileOptions: not supported
750
751     # ExpandDimsOptions:
752     import tflite.ExpandDimsOptions
753     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpandDimsOptions:
754
755         expanddims_option = tflite.ExpandDimsOptions.ExpandDimsOptions()
756         expanddims_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
757
758         tflite.ExpandDimsOptions.ExpandDimsOptionsStart(new_builder)
759         return tflite.ExpandDimsOptions.ExpandDimsOptionsEnd(new_builder)
760
761     # NotEqualOptions:
762     import tflite.NotEqualOptions
763     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NotEqualOptions:
764
765         notequal_option = tflite.NotEqualOptions.NotEqualOptions()
766         notequal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
767
768         tflite.NotEqualOptions.NotEqualOptionsStart(new_builder)
769         return tflite.NotEqualOptions.NotEqualOptionsEnd(new_builder)
770
771     # ShapeOptions:
772     import tflite.ShapeOptions
773     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ShapeOptions:
774
775         shape_option = tflite.ShapeOptions.ShapeOptions()
776         shape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
777
778         tflite.ShapeOptions.ShapeOptionsStart(new_builder)
779         tflite.ShapeOptions.ShapeOptionsAddOutType(new_builder, shape_option.OutType())
780         return tflite.ShapeOptions.ShapeOptionsEnd(new_builder)
781
782     # PowOptions: not supported
783     # ArgMinOptions: not supported
784     # FakeQuantOptions: not supported
785
786     # PackOptions:
787     import tflite.PackOptions
788     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PackOptions:
789
790         pack_option = tflite.PackOptions.PackOptions()
791         pack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
792
793         tflite.PackOptions.PackOptionsStart(new_builder)
794         tflite.PackOptions.PackOptionsAddValuesCount(new_builder,
795                                                      pack_option.ValuesCount())
796         tflite.PackOptions.PackOptionsAddAxis(new_builder, pack_option.Axis())
797         return tflite.PackOptions.PackOptionsEnd(new_builder)
798
799     # LogicalOrOptions:
800     import tflite.LogicalOrOptions
801     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalOrOptions:
802
803         logical_or_option = tflite.LogicalAndOptions.LogicalOrOptions()
804         logical_or_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
805
806         tflite.LogicalOrOptions.LogicalOrOptionsStart(new_builder)
807         return tflite.LogicalOrOptions.LogicalOrOptionsEnd(new_builder)
808
    # OneHotOptions
810     import tflite.OneHotOptions
811     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().OneHotOptions:
812
813         one_hot_option = tflite.OneHotOptions.OneHotOptions()
814         one_hot_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
815
816         tflite.OneHotOptions.OneHotOptionsStart(new_builder)
817         tflite.OneHotOptions.OneHotOptionsAddAxis(new_builder, one_hot_option.Axis())
818         return tflite.OneHotOptions.OneHotOptionsEnd(new_builder)
819
820     # LogicalNotOptions
821     import tflite.LogicalNotOptions
822     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalNotOptions:
823
824         equal_option = tflite.LogicalNotOptions.LogicalNotOptions()
825         equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
826
827         tflite.LogicalNotOptions.LogicalNotOptionsStart(new_builder)
828         return tflite.LogicalNotOptions.LogicalNotOptionsEnd(new_builder)
829
830     # UnpackOptions:
831     import tflite.UnpackOptions
832     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().UnpackOptions:
833
834         unpack_option = tflite.UnpackOptions.UnpackOptions()
835         unpack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
836
837         tflite.UnpackOptions.UnpackOptionsStart(new_builder)
838         tflite.UnpackOptions.UnpackOptionsAddNum(new_builder, unpack_option.Num())
839         tflite.UnpackOptions.UnpackOptionsAddAxis(new_builder, unpack_option.Axis())
840         return tflite.UnpackOptions.UnpackOptionsEnd(new_builder)
841
842     # FloorDivOptions: not supported
843     # SquareOptions: not supported
844     # ZerosLikeOptions: not supported
845     # FillOptions: not supported
846
847     # LogicalAndOptions
848     import tflite.LogicalAndOptions
849     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalAndOptions:
850
851         logical_and_option = tflite.LogicalAndOptions.LogicalAndOptions()
852         logical_and_option.Init(selected_builtin_option.Bytes,
853                                 selected_builtin_option.Pos)
854
855         tflite.LogicalAndOptions.LogicalAndOptionsStart(new_builder)
856         return tflite.LogicalAndOptions.LogicalAndOptionsEnd(new_builder)
857
858     # LogicalNotOptions: not supported
859     # UnpackOptions: not supported
860     # FloorDivOptions: not supported
861     # SquareOptions: not supported
862     # ZerosLikeOptions: not supported
863     # FillOptions: not supported
864     # BidirectionalSequenceLSTMOptions: not supported
865     # BidirectionalSequenceRNNOptions: not supported
866     # UnidirectionalSequenceLSTMOptions: not supported
867     # FloorModOptions: not supported
868     # RangeOptions: not supported
869     # ResizeNearestNeighborOptions: not supported
870     # LeakyReluOptions: not supported
871
872     # SquaredDifferenceOptions
873     import tflite.SquaredDifferenceOptions
874     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
875     ).SquaredDifferenceOptions:
876
877         squared_difference_option = tflite.SquaredDifferenceOptions.SquaredDifferenceOptions(
878         )
879         squared_difference_option.Init(selected_builtin_option.Bytes,
880                                        selected_builtin_option.Pos)
881
882         tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsStart(new_builder)
883         return tflite.SquaredDifferenceOptions.SquaredDifferenceOptionsEnd(new_builder)
884
885     # MirrorPadOptions: not supported
886     # AbsOptions: not supported
887     # SplitVOptions: not supported
888
889     # IfOptions
890     import tflite.IfOptions
891     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().IfOptions:
892
893         if_option = tflite.IfOptions.IfOptions()
894         if_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
895
896         tflite.IfOptions.IfOptionsStart(new_builder)
897         tflite.IfOptions.IfOptionsAddElseSubgraphIndex(
898             new_builder, used_subgraphs_dic[if_option.ElseSubgraphIndex()])
899         tflite.IfOptions.IfOptionsAddThenSubgraphIndex(
900             new_builder, used_subgraphs_dic[if_option.ThenSubgraphIndex()])
901         return tflite.IfOptions.IfOptionsEnd(new_builder)
902
903     # WhileOptions
904     import tflite.WhileOptions
905     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().WhileOptions:
906
907         while_option = tflite.WhileOptions.WhileOptions()
908         while_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
909
910         tflite.WhileOptions.WhileOptionsStart(new_builder)
911         tflite.WhileOptions.WhileOptionsAddBodySubgraphIndex(
912             new_builder, used_subgraphs_dic[while_option.BodySubgraphIndex()])
913         tflite.WhileOptions.WhileOptionsAddCondSubgraphIndex(
914             new_builder, used_subgraphs_dic[while_option.CondSubgraphIndex()])
915         return tflite.WhileOptions.WhileOptionsEnd(new_builder)
916
917     # Cannot handle builtin option type yet
918     print("Cannot handle this option yet")
919     exit(1)
920
921
def GenerateOperator(new_builder, selected_operator, used_tensors_dic, used_opcodes_dic,
                     used_subgraphs_dic):
    """Write a copy of `selected_operator` into `new_builder`.

    The opcode index is translated through `used_opcodes_dic` and the tensor
    indices through `used_tensors_dic`; `used_subgraphs_dic` is handed on to
    GenerateBuiltinOption. Returns the value produced by OperatorEnd().
    """
    remapped_opcode = used_opcodes_dic[selected_operator.OpcodeIndex()]

    # Input tensor indices. Elements are prepended, so the source is walked
    # back to front to keep the original order.
    n_inputs = selected_operator.InputsLength()
    if n_inputs != 0:
        tflite.Operator.OperatorStartInputsVector(new_builder, n_inputs)
        for pos in range(n_inputs - 1, -1, -1):
            old_idx = selected_operator.Inputs(pos)
            # -1 is copied through untouched instead of being looked up
            remapped = -1 if old_idx == -1 else used_tensors_dic[old_idx]
            new_builder.PrependInt32(remapped)
        inputs_offset = new_builder.EndVector(n_inputs)

    # Output tensor indices (every one of them goes through the dictionary)
    n_outputs = selected_operator.OutputsLength()
    if n_outputs != 0:
        tflite.Operator.OperatorStartOutputsVector(new_builder, n_outputs)
        for pos in range(n_outputs - 1, -1, -1):
            new_builder.PrependInt32(used_tensors_dic[selected_operator.Outputs(pos)])
        outputs_offset = new_builder.EndVector(n_outputs)

    # Builtin option table; type 0 means there is nothing to copy
    option_type = selected_operator.BuiltinOptionsType()
    if option_type != 0:
        option_offset = GenerateBuiltinOption(new_builder,
                                              selected_operator.BuiltinOptions(),
                                              option_type, used_subgraphs_dic)

    # Custom option bytes, copied one by one
    n_custom_bytes = selected_operator.CustomOptionsLength()
    if n_custom_bytes != 0:
        tflite.Operator.OperatorStartCustomOptionsVector(new_builder, n_custom_bytes)
        for pos in range(n_custom_bytes - 1, -1, -1):
            new_builder.PrependUint8(selected_operator.CustomOptions(pos))
        custom_offset = new_builder.EndVector(n_custom_bytes)

    # Custom option format is carried over unchanged
    custom_format = selected_operator.CustomOptionsFormat()

    # All vectors and sub-tables exist now, so the operator table can be assembled
    tflite.Operator.OperatorStart(new_builder)
    tflite.Operator.OperatorAddOpcodeIndex(new_builder, remapped_opcode)
    if n_inputs != 0:
        tflite.Operator.OperatorAddInputs(new_builder, inputs_offset)
    if n_outputs != 0:
        tflite.Operator.OperatorAddOutputs(new_builder, outputs_offset)
    tflite.Operator.OperatorAddBuiltinOptionsType(new_builder, option_type)
    if option_type != 0:
        tflite.Operator.OperatorAddBuiltinOptions(new_builder, option_offset)
    if n_custom_bytes != 0:
        tflite.Operator.OperatorAddCustomOptions(new_builder, custom_offset)
    tflite.Operator.OperatorAddCustomOptionsFormat(new_builder, custom_format)
    return tflite.Operator.OperatorEnd(new_builder)
984
985
def GenerateOperators(new_builder, selected_subgraph, operator_list, used_tensors_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Write the operators vector of a subgraph into `new_builder`.

    Only operators whose index appears in `operator_list` are copied, in
    ascending index order. Returns the value of EndVector(), or 0 when the
    subgraph has no operators or none of them is selected.
    """
    total = selected_subgraph.OperatorsLength()
    if total == 0:
        return 0

    # One operator table per selected index, built before the vector is opened
    offsets = [
        GenerateOperator(new_builder, selected_subgraph.Operators(idx), used_tensors_dic,
                         used_opcodes_dic, used_subgraphs_dic)
        for idx in range(total) if idx in operator_list
    ]
    if not offsets:
        return 0

    count = len(offsets)
    tflite.SubGraph.SubGraphStartOperatorsVector(new_builder, count)
    # Offsets are prepended, hence the reverse walk
    for offset in offsets[::-1]:
        new_builder.PrependUOffsetTRelative(offset)

    return new_builder.EndVector(count)
1011
1012
def GenerateSubgraph(new_builder, selected_subgraph, operator_list, new_input_tensor,
                     new_output_tensor, used_tensors_dic, used_buffers_dic,
                     used_opcodes_dic, used_subgraphs_dic):
    """Write one subgraph table into `new_builder` and return SubGraphEnd().

    `new_input_tensor` / `new_output_tensor` hold tensor indices of the source
    subgraph; they are translated through `used_tensors_dic` before being
    stored as the inputs / outputs of the new subgraph.
    """
    # Tensors come first, then inputs, outputs and operators
    tensors_offset = GenerateTensors(new_builder, selected_subgraph, used_tensors_dic,
                                     used_buffers_dic)

    # Subgraph inputs (skipped entirely when the list is empty)
    input_count = len(new_input_tensor)
    if input_count != 0:
        tflite.SubGraph.SubGraphStartInputsVector(new_builder, input_count)
        for old_idx in reversed(new_input_tensor):
            new_builder.PrependInt32(used_tensors_dic[old_idx])
        inputs_offset = new_builder.EndVector(input_count)

    # Subgraph outputs (same scheme)
    output_count = len(new_output_tensor)
    if output_count != 0:
        tflite.SubGraph.SubGraphStartOutputsVector(new_builder, output_count)
        for old_idx in reversed(new_output_tensor):
            new_builder.PrependInt32(used_tensors_dic[old_idx])
        outputs_offset = new_builder.EndVector(output_count)

    operators_offset = GenerateOperators(new_builder, selected_subgraph, operator_list,
                                         used_tensors_dic, used_opcodes_dic,
                                         used_subgraphs_dic)

    # The name is only written when the source subgraph has a non-empty one
    source_name = selected_subgraph.Name()
    name_offset = None
    if source_name:
        name_offset = new_builder.CreateString(source_name)

    tflite.SubGraph.SubGraphStart(new_builder)
    tflite.SubGraph.SubGraphAddTensors(new_builder, tensors_offset)
    if input_count != 0:
        tflite.SubGraph.SubGraphAddInputs(new_builder, inputs_offset)
    if output_count != 0:
        tflite.SubGraph.SubGraphAddOutputs(new_builder, outputs_offset)
    tflite.SubGraph.SubGraphAddOperators(new_builder, operators_offset)
    if name_offset is not None:
        tflite.SubGraph.SubGraphAddName(new_builder, name_offset)

    return tflite.SubGraph.SubGraphEnd(new_builder)
1061
1062
def GenerateSubgraphs(args, new_builder, sample_model, operator_list, new_input_tensor,
                      new_output_tensor, used_tensors_dic, used_buffers_dic,
                      used_opcodes_dic, used_subgraphs_dic):
    """Write the subgraphs vector of the new model and return EndVector().

    `used_subgraphs_dic` maps an index in `sample_model` (key) to the index the
    subgraph gets in the new model (value). The subgraph `args.subgraph` is
    reduced to the operators in `operator_list`; every other listed subgraph
    is copied with all of its operators and tensors.
    """
    new_subgraph_list = []

    # The selected subgraph will be primary subgraph of the model to be created newly
    selected_subgraph = sample_model.Subgraphs(args.subgraph)

    # k: old subg index, v: new subg index
    # A subgraph's position in the output vector IS its new index, so walk the
    # entries ordered by v rather than trusting the dictionary's iteration
    # order (which is only guaranteed to be insertion order from Python 3.7 on).
    for k, v in sorted(used_subgraphs_dic.items(), key=lambda item: item[1]):
        print("Append subgraphs, old index : ", k, ", new index : ", v)
        if k == args.subgraph:
            assert v == 0
            new_subgraph = GenerateSubgraph(new_builder, selected_subgraph, operator_list,
                                            new_input_tensor, new_output_tensor,
                                            used_tensors_dic, used_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
            new_subgraph_list.append(new_subgraph)
        else:
            # Child subgraph: keep every operator and tensor, so tensor indices
            # map to themselves
            subg = sample_model.Subgraphs(k)
            subg_opperator_idx_list = range(subg.OperatorsLength())
            subg_input_tensors = subg.InputsAsNumpy()
            subg_output_tensors = subg.OutputsAsNumpy()
            subg_tensors = range(subg.TensorsLength())
            subg_tensors_dic = {tensor_idx: tensor_idx for tensor_idx in subg_tensors}
            # NOTE(review): buffer indices are also mapped to themselves here,
            # while GenerateBuffers writes only the buffers listed in
            # used_buffers_dic. If any buffer of the source model is dropped,
            # these indices presumably no longer line up -- confirm against
            # GenerateTensors.
            subg_buffers_dic = {(subg.Tensors(idx)).Buffer():
                                (subg.Tensors(idx)).Buffer()
                                for idx in subg_tensors}
            new_subgraph = GenerateSubgraph(new_builder, subg, subg_opperator_idx_list,
                                            subg_input_tensors, subg_output_tensors,
                                            subg_tensors_dic, subg_buffers_dic,
                                            used_opcodes_dic, used_subgraphs_dic)
            new_subgraph_list.append(new_subgraph)

    new_subgraph_num = len(new_subgraph_list)
    tflite.Model.ModelStartSubgraphsVector(new_builder, new_subgraph_num)
    # Offsets are prepended, hence the reverse walk
    for new_subgraph in reversed(new_subgraph_list):
        new_builder.PrependUOffsetTRelative(new_subgraph)

    return new_builder.EndVector(new_subgraph_num)
1105
1106
def GenerateBuffers(new_builder, sample_model, used_buffers_dic):
    """Write the buffers vector of the new model into `new_builder`.

    Only buffers whose index is a key of `used_buffers_dic` are copied, in
    ascending index order. Returns the value of EndVector(), or 0 when the
    source model has no buffers or none of them is used.
    """
    total = sample_model.BuffersLength()
    if total == 0:
        return 0

    kept_indices = [idx for idx in range(total) if idx in used_buffers_dic]

    # Pass 1: data vectors, for kept buffers that actually hold bytes
    data_offsets = {}
    for idx in kept_indices:
        source = sample_model.Buffers(idx)
        size = source.DataLength()
        if size == 0:
            continue
        tflite.Buffer.BufferStartDataVector(new_builder, size)
        # Bytes are prepended, so copy them back to front
        for pos in range(size - 1, -1, -1):
            new_builder.PrependUint8(source.Data(pos))
        data_offsets[idx] = new_builder.EndVector(size)

    # Pass 2: one buffer table per kept buffer (empty ones get no data field)
    table_offsets = []
    for idx in kept_indices:
        tflite.Buffer.BufferStart(new_builder)
        if idx in data_offsets:
            tflite.Buffer.BufferAddData(new_builder, data_offsets[idx])
        table_offsets.append(tflite.Buffer.BufferEnd(new_builder))

    if not table_offsets:
        return 0

    # Pass 3: the vector that ties the tables together
    table_count = len(table_offsets)
    tflite.Model.ModelStartBuffersVector(new_builder, table_count)
    for offset in table_offsets[::-1]:
        new_builder.PrependUOffsetTRelative(offset)

    return new_builder.EndVector(table_count)
1149
1150
def GenerateModel(args, new_builder, sample_model, operator_list, new_input_tensors,
                  new_output_tensors, used_tensors_dic, used_buffers_dic,
                  used_opcodes_dic, used_subgraphs_dic):
    """Write the root model table into `new_builder` and return ModelEnd().

    Version and description are taken from `sample_model`; operator codes,
    subgraphs and buffers are regenerated by the corresponding Generate*
    helpers using the given index dictionaries.
    """
    # uint
    version = sample_model.Version()

    # pointer of operator code 'table' vector
    operator_codes = GenerateOperatorCodes(new_builder, sample_model, used_opcodes_dic,
                                           used_subgraphs_dic)

    # subgraphs
    subgraphs = GenerateSubgraphs(args, new_builder, sample_model, operator_list,
                                  new_input_tensors, new_output_tensors, used_tensors_dic,
                                  used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)

    # description: may be missing in the source model, in which case there is
    # nothing to pass to CreateString and the field is simply left out
    description = sample_model.Description()
    description_string = None
    if description is not None:
        description_string = new_builder.CreateString(description)

    # buffers
    buffers = GenerateBuffers(new_builder, sample_model, used_buffers_dic)

    # Generate model
    tflite.Model.ModelStart(new_builder)
    tflite.Model.ModelAddVersion(new_builder, version)
    tflite.Model.ModelAddOperatorCodes(new_builder, operator_codes)
    tflite.Model.ModelAddSubgraphs(new_builder, subgraphs)
    if description_string is not None:
        tflite.Model.ModelAddDescription(new_builder, description_string)
    tflite.Model.ModelAddBuffers(new_builder, buffers)

    return tflite.Model.ModelEnd(new_builder)
1181
1182
def main(args):
    """Extract the selected operators of one subgraph into a new tflite model.

    Reads the model from `args.input_model` and the operator selection from
    `args.opcode_list`, keeps the tensors, opcodes and buffers needed by the
    selected operators of subgraph `args.subgraph` (plus the subgraphs that
    GetUsedSubgraphsList adds to the list), and writes the resulting model to
    `args.output_model`.
    """
    input_model_file = args.input_model
    oplist_file = args.opcode_list
    output_model_file = args.output_model
    subgraph = args.subgraph

    # Parse operator list file
    operator_list = GetOperatorList(oplist_file)

    # Get sample model and the subgraph selected with --subgraph
    sample_buf = bytearray(input_model_file.read())
    sample_model = tflite.Model.Model.GetRootAsModel(sample_buf, 0)
    sample_subgraph = sample_model.Subgraphs(subgraph)

    # The selected subgraph is always first, so it gets new index 0
    used_subgraphs_list = [subgraph]
    GetUsedSubgraphsList(sample_model, subgraph, operator_list, used_subgraphs_list)

    used_subgraphs_dic = {}
    for new_subgraph_idx, sample_subgraph_idx in enumerate(used_subgraphs_list):
        used_subgraphs_dic[sample_subgraph_idx] = new_subgraph_idx

    # Collect used tensor & used operator
    used_tensors = []
    used_opcodes = []

    for operator_idx in operator_list:
        operator = sample_subgraph.Operators(operator_idx)
        for input_idx in range(operator.InputsLength()):
            input_tensor_idx = operator.Inputs(input_idx)
            # -1 entries are not real tensors and are never collected
            if input_tensor_idx != -1 and input_tensor_idx not in used_tensors:
                used_tensors.append(input_tensor_idx)

        for output_idx in range(operator.OutputsLength()):
            output_tensor_idx = operator.Outputs(output_idx)
            if output_tensor_idx not in used_tensors:
                used_tensors.append(output_tensor_idx)

        opcode_idx = operator.OpcodeIndex()
        if opcode_idx not in used_opcodes:
            used_opcodes.append(opcode_idx)

    # Append opcodes of child subgraphs (they are copied with all operators)
    for subgraph_idx in used_subgraphs_list:
        if subgraph_idx == subgraph:
            continue
        child_subgraph = sample_model.Subgraphs(subgraph_idx)
        for operator_idx in range(child_subgraph.OperatorsLength()):
            opcode_idx = child_subgraph.Operators(operator_idx).OpcodeIndex()
            if opcode_idx not in used_opcodes:
                used_opcodes.append(opcode_idx)

    used_tensors.sort()
    used_opcodes.sort()

    # Collect used buffer
    # buffer[0] should be blank, so it is always kept.
    # A set is used because several tensors may refer to the same buffer:
    # GenerateBuffers writes one buffer per distinct index, so numbering a
    # list with duplicates would hand out indices past the end of that vector.
    used_buffer_set = {0}

    for used_tensor in used_tensors:
        used_buffer_set.add(sample_subgraph.Tensors(used_tensor).Buffer())

    # Append buffers of tensors of child subgraphs
    for subgraph_idx in used_subgraphs_list:
        if subgraph_idx == subgraph:
            continue
        child_subgraph = sample_model.Subgraphs(subgraph_idx)
        for tensor_idx in range(child_subgraph.TensorsLength()):
            used_buffer_set.add(child_subgraph.Tensors(tensor_idx).Buffer())

    used_buffers = sorted(used_buffer_set)

    # Assign new (dense, ascending) indices: old index -> new index
    used_opcodes_dic = {old: new for new, old in enumerate(used_opcodes)}
    used_tensors_dic = {old: new for new, old in enumerate(used_tensors)}
    used_buffers_dic = {old: new for new, old in enumerate(used_buffers)}

    # Find input & output tensor in new model.
    # Start from "every used tensor is both" and strike out:
    #   - from the outputs: tensors consumed by a selected operator
    #   - from the inputs:  tensors produced by a selected operator
    #   - from both:        tensors whose buffer holds data
    new_input_tensors = used_tensors[:]
    new_output_tensors = used_tensors[:]

    for operator_idx in operator_list:
        operator = sample_subgraph.Operators(operator_idx)
        for input_idx in range(operator.InputsLength()):
            input_tensor_idx = operator.Inputs(input_idx)
            if input_tensor_idx == -1:
                continue
            if input_tensor_idx in new_output_tensors:
                new_output_tensors.remove(input_tensor_idx)
            if input_tensor_idx in new_input_tensors:
                matched_buffer_idx = sample_subgraph.Tensors(input_tensor_idx).Buffer()
                matched_buffer = sample_model.Buffers(matched_buffer_idx)
                if matched_buffer.DataLength() != 0:
                    new_input_tensors.remove(input_tensor_idx)

        for output_idx in range(operator.OutputsLength()):
            output_tensor_idx = operator.Outputs(output_idx)
            if output_tensor_idx in new_input_tensors:
                new_input_tensors.remove(output_tensor_idx)
            if output_tensor_idx in new_output_tensors:
                matched_buffer_idx = sample_subgraph.Tensors(output_tensor_idx).Buffer()
                matched_buffer = sample_model.Buffers(matched_buffer_idx)
                if matched_buffer.DataLength() != 0:
                    # Remove the output tensor that was just examined (this
                    # used to remove the leftover input index of the loop above)
                    new_output_tensors.remove(output_tensor_idx)

    new_input_tensors_newidx = [used_tensors_dic[idx] for idx in new_input_tensors]
    new_output_tensors_newidx = [used_tensors_dic[idx] for idx in new_output_tensors]

    print("Input tensor(s): " + str(new_input_tensors_newidx))
    print("Output tensor(s): " + str(new_output_tensors_newidx))

    # Create new model file
    new_builder = flatbuffers.Builder(1024)

    new_model = GenerateModel(args, new_builder, sample_model, operator_list,
                              new_input_tensors, new_output_tensors, used_tensors_dic,
                              used_buffers_dic, used_opcodes_dic, used_subgraphs_dic)

    new_builder.Finish(new_model, file_identifier=b'TFL3')
    new_buf = new_builder.Output()

    output_model_file.write(new_buf)
1333
1334
if __name__ == '__main__':
    # Command line: input_model opcode_list output_model [-g SUBGRAPH]
    parser = argparse.ArgumentParser()

    # Positional file arguments: (name, open mode, help text), in CLI order
    file_arguments = (
        ("input_model", 'rb', "input tflite model file to read"),
        ("opcode_list", 'r', "text file including selected operator list"),
        ("output_model", 'wb', "output tflite model file"),
    )
    for argument_name, open_mode, help_text in file_arguments:
        parser.add_argument(
            argument_name, type=argparse.FileType(open_mode), help=help_text)

    parser.add_argument(
        '-g', '--subgraph', type=int, default=0, help="subgraph to use (default: 0)")

    # TODO
    #   Select multiple subgraph
    #   Select subgraph by using opcode list file
    #   Select opcode list by using argument

    # Parse the command line and hand it to main
    main(parser.parse_args())