8ad444d95791c2abfdc626d9d5e4990b665c6268
[platform/core/ml/nnfw.git] / nnpackage / schema / circle_schema.fbs
1 // Copyright (c) 2019~2022 Samsung Electronics Co., Ltd. All Rights Reserved
2 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15
16 // Revision History
17 //
18 // Version Major.Minor
19 //
20 // Major version is schema version.
21 // We keep schema version if it is compatible
22 // Minor version is for human communication
23 // It will not be stored in circle model.
24 //
25 // Version 0.0: Initial version. Based on TensorFlow Lite v1.13.1 schema.
26 // Version 0.1: Based on TF v2.2-rc2 + more (from TensorFlow `56d281c`)
27 //              `BATCH_MATMUL` operator, `FLOAT64` tensor type,
28 //              `asymmetric_quantize_inputs` for several operator options
29 // Version 0.2: BCQ_GATHER and BCQ_FULLY_CONNECTED are added.
30 // Version 0.3: SHUFFLED16x1FLOAT32 is added.
31 // Version 0.4: Base up to TensorFlow Lite v2.7.0 schema.
32
33 namespace circle;
34
35 // This corresponds to the version.
36 file_identifier "CIR0";
37 // File extension of any written files.
38 file_extension "circle";
39
40 // IMPORTANT: All new members of tables, enums and unions must be added at the
41 // end to ensure backwards compatibility.
42
43 // The type of data stored in a tensor.
44 enum TensorType : byte {
45   FLOAT32 = 0,
46   FLOAT16 = 1,
47   INT32 = 2,
48   UINT8 = 3,
49   INT64 = 4,
50   STRING = 5,
51   BOOL = 6,
52   INT16 = 7,
53   COMPLEX64 = 8,
54   INT8 = 9,
55   FLOAT64 = 10,
56   COMPLEX128 = 11,
57   UINT64 = 12,
58   // Experimental: Resource and variant types are experimental, that are subject
59   // to change. Do not implement custom kernels using resource & variant types
60   // now.
61   RESOURCE = 13,
62   VARIANT = 14,
63   UINT32 = 15,
64 }
65
66 // Custom quantization parameters for experimenting with new quantization
67 // techniques.
68 table CustomQuantization {
69   custom:[ubyte] (force_align: 16);
70 }
71
72 // Represents a specific quantization technique's parameters.
73 union QuantizationDetails {
74   CustomQuantization,
75 }
76
77 // Parameters for converting a quantized tensor back to float.
78 table QuantizationParameters {
79   // These four parameters are the asymmetric linear quantization parameters.
80   // Given a quantized value q, the corresponding float value f should be:
81   //   f = scale * (q - zero_point)
82   // For other quantization types, the QuantizationDetails below is used.
83   min:[float];  // For importing back into tensorflow.
84   max:[float];  // For importing back into tensorflow.
85   scale:[float];  // For dequantizing the tensor's values.
86   zero_point:[long];
87
88   // If this is not none, the other quantization parameters (i.e. min, max,
89   // scale, zero_point fields above) are ignored and the value of the
90   // QuantizationDetails union should be used.
91   details:QuantizationDetails;
92
93   // Specifies the dimension of the Tensor's shape that the scales and
94   // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
95   // with quantization params:
96   //   scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1
97   // will be quantized across the second dimension of t.
98   //   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
99   //   t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
100   //   t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
101   quantized_dimension:int;
102 }
103
104 // Sparse tensors.
105 // We use a modification of the TACO format.
106 // Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf
107 //
108 // To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1),
109 // potentially with a k-dimensional block (0 <= k <= n) with dims
110 // (dn, ..., dn+k-1), the format needs to specify:
111 //   1. In what order to traverse these dimensions. For example, to store a 2-D
112 //      matrix in row major order, the traversal order would be (d0, d1),
113 //      whereas to store it in column major order, the traversal order would be
114 //      (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order
115 //      could be (d0, d1, d2, d3).
116 //   2. How each block dimension in (dn, ..., dn+k-1) maps to the original
117 //      tensor dimension in (d0, ..., dn-1).
118 //   3. In the traversal order defined above, the format (dense vs. sparse) and
119 //      index metadata for each dimension. For a dense dimension, this is just
120 //      the size of that dimension. For a sparse dimension, it's the same as
121 //      the compressed index defined in the Compressed Sparse Row (CSR) format.
122 //      (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html)
123
124 // The storage type for a dimension. Currently we support:
125 //   1. DENSE: each coordinate in this dimension is stored implicitly.
126 //   2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The
127 //      compression technique is the same what CSR uses.
128 // More types like a sparse dimension with a different compression technique
129 // could be added to the list in the future.
130 enum DimensionType : byte {
131   DENSE = 0,
132   SPARSE_CSR = 1,
133 }
134
135 table Int32Vector {
136   values:[int];
137 }
138
139 table Uint16Vector {
140   values:[ushort] (force_align: 4);
141 }
142
143 table Uint8Vector {
144   values:[ubyte] (force_align: 4);
145 }
146
147 // Variable-typed buffer to store the index metadata for a sparse dimension.
148 // The widest type is Int32 instead of UInt32 because tensor's shape is a int32
149 // vector. We don't want the per-dimensional index to overflow that range.
150 union SparseIndexVector {
151   Int32Vector,
152   Uint16Vector,
153   Uint8Vector
154 }
155
156 table DimensionMetadata {
157   // Whether a dimension is dense or sparse.
158   format:DimensionType;
159   // Index metadata used for a dimension.
160   //   - If format is DimensionType.DENSE then we use the dense_size field to
161   //     store the size of that dimension. Each index in that dimension is
162   //     stored implicitly.
163   //   - If format is DimensionType.SPARSE_CSR then we use array_segments and
164   //     array_indices to encode that dimension. array_segments represents how
165   //     to segment the indices array, each segment corresponds to one element
166   //     in the previous dimension. array_indices represents the index of the
167   //     non-zero elements within this dimension (as those in the CSR matrix
168   //     format, where the first array is row pointers and the second array is
169   //     column indices).
170   dense_size:int;
171   array_segments:SparseIndexVector;
172   array_indices:SparseIndexVector;
173 }
174
175 // Parameters to encode a sparse TfLite tensor.
176 table SparsityParameters {
177   // The traversal order of the dimensions defined in the `shape` field of the
178   // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1,
179   // ..., dn-1),
180   //   - if not block sparse, the traversal_order is just a permutation of (d0,
181   //     ..., dn-1). For example, a 2-D matrix stored in row-major order would
182   //     have traversal_order = (d0, d1).
183   //   - if block sparse with a k-dimensional block (0 <= k <= n), the
184   //     traversal_order has n + k elements. The first n elements are still a
185   //     permutation of (d0, ..., dn-1). The lask k elements are a permutation
186   //     of (dn, ..., dn+k-1), defining how to traverse a block internally. For
187   //     example, a 2-D matrix with 2-D blocks, both stored in row-major order
188   //     would have traversal_order = (d0, d1, d2, d3).
189   traversal_order:[int];
190   // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
191   // stores how a block dimension in (dn, ..., dn+k-1) maps to the original
192   // tensor dimension in (d0, ..., dn).
193   // It's stored in the order of (dn, ..., dn+k-1).
194   // If not block-sparse, this field is NULL.
195   block_map:[int];
196   // In the traversal order defined above, the metadata needed for
197   // each dimension to locate the non-zero values in the original dense tensor.
198   // The size of the dim_metadata array = the size of the traversal_order array
199   // = n + k.
200   dim_metadata:[DimensionMetadata];
201 }
202
203 table Tensor {
204   // The tensor shape. The meaning of each entry is operator-specific but
205   // builtin ops use: [batch size, height, width, number of channels] (That's
206   // Tensorflow's NHWC).
207   shape:[int];
208   type:TensorType;
209   // An index that refers to the buffers table at the root of the model. Or,
210   // if there is no data buffer associated (i.e. intermediate results), then
211   // this is 0 (which refers to an always existent empty buffer).
212   //
213   // The data_buffer itself is an opaque container, with the assumption that the
214   // target device is little-endian. In addition, all builtin operators assume
215   // the memory is ordered such that if `shape` is [4, 3, 2], then index
216   // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
217   buffer:uint;
218   name:string;  // For debugging and importing back into tensorflow.
219   quantization:QuantizationParameters;  // Optional.
220
221   is_variable:bool = false;
222
223   // Parameters to encode a sparse tensor. See the example in
224   // tensorflow/lite/testdata/sparse_tensor.json.
225   sparsity:SparsityParameters;  // Optional.
226
227   // Encodes `shape` with unknown dimensions. Unknown dimensions are
228   // represented with -1.
229   shape_signature:[int]; // Optional.
230 }
231
232 // A list of builtin operators. Builtin operators are slightly faster than custom
233 // ones, but not by much. Moreover, while custom operators accept an opaque
234 // object containing configuration parameters, builtins have a predetermined
235 // set of acceptable options.
236 // LINT.IfChange
237 enum BuiltinOperator : int32 {
238   BCQ_GATHER = -4,
239   BCQ_FULLY_CONNECTED = -3,
240   INSTANCE_NORM = -2,
241   ADD = 0,
242   AVERAGE_POOL_2D = 1,
243   CONCATENATION = 2,
244   CONV_2D = 3,
245   DEPTHWISE_CONV_2D = 4,
246   DEPTH_TO_SPACE = 5,
247   DEQUANTIZE = 6,
248   EMBEDDING_LOOKUP = 7,
249   FLOOR = 8,
250   FULLY_CONNECTED = 9,
251   HASHTABLE_LOOKUP = 10,
252   L2_NORMALIZATION = 11,
253   L2_POOL_2D = 12,
254   LOCAL_RESPONSE_NORMALIZATION = 13,
255   LOGISTIC = 14,
256   LSH_PROJECTION = 15,
257   LSTM = 16,
258   MAX_POOL_2D = 17,
259   MUL = 18,
260   RELU = 19,
261   // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed
262   // since different model developers use RELU1 in different ways. Never
263   // create another op called RELU1.
264   RELU_N1_TO_1 = 20,
265   RELU6 = 21,
266   RESHAPE = 22,
267   RESIZE_BILINEAR = 23,
268   RNN = 24,
269   SOFTMAX = 25,
270   SPACE_TO_DEPTH = 26,
271   SVDF = 27,
272   TANH = 28,
273   CONCAT_EMBEDDINGS = 29,
274   SKIP_GRAM = 30,
275   CALL = 31,
276   CUSTOM = 32,
277   EMBEDDING_LOOKUP_SPARSE = 33,
278   PAD = 34,
279   UNIDIRECTIONAL_SEQUENCE_RNN = 35,
280   GATHER = 36,
281   BATCH_TO_SPACE_ND = 37,
282   SPACE_TO_BATCH_ND = 38,
283   TRANSPOSE = 39,
284   MEAN = 40,
285   SUB = 41,
286   DIV = 42,
287   SQUEEZE = 43,
288   UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
289   STRIDED_SLICE = 45,
290   BIDIRECTIONAL_SEQUENCE_RNN = 46,
291   EXP = 47,
292   TOPK_V2 = 48,
293   SPLIT = 49,
294   LOG_SOFTMAX = 50,
295   // DELEGATE is a special op type for the operations which are delegated to
296   // other backends.
297   // WARNING: Experimental interface, subject to change
298   DELEGATE = 51,
299   BIDIRECTIONAL_SEQUENCE_LSTM = 52,
300   CAST = 53,
301   PRELU = 54,
302   MAXIMUM = 55,
303   ARG_MAX = 56,
304   MINIMUM = 57,
305   LESS = 58,
306   NEG = 59,
307   PADV2 = 60,
308   GREATER = 61,
309   GREATER_EQUAL = 62,
310   LESS_EQUAL = 63,
311   SELECT = 64,
312   SLICE = 65,
313   SIN = 66,
314   TRANSPOSE_CONV = 67,
315   SPARSE_TO_DENSE = 68,
316   TILE = 69,
317   EXPAND_DIMS = 70,
318   EQUAL = 71,
319   NOT_EQUAL = 72,
320   LOG = 73,
321   SUM = 74,
322   SQRT = 75,
323   RSQRT = 76,
324   SHAPE = 77,
325   POW = 78,
326   ARG_MIN = 79,
327   FAKE_QUANT = 80,
328   REDUCE_PROD = 81,
329   REDUCE_MAX = 82,
330   PACK = 83,
331   LOGICAL_OR = 84,
332   ONE_HOT = 85,
333   LOGICAL_AND = 86,
334   LOGICAL_NOT = 87,
335   UNPACK = 88,
336   REDUCE_MIN = 89,
337   FLOOR_DIV = 90,
338   REDUCE_ANY = 91,
339   SQUARE = 92,
340   ZEROS_LIKE = 93,
341   FILL = 94,
342   FLOOR_MOD = 95,
343   RANGE = 96,
344   RESIZE_NEAREST_NEIGHBOR = 97,
345   LEAKY_RELU = 98,
346   SQUARED_DIFFERENCE = 99,
347   MIRROR_PAD = 100,
348   ABS = 101,
349   SPLIT_V = 102,
350   UNIQUE = 103,
351   CEIL = 104,
352   REVERSE_V2 = 105,
353   ADD_N = 106,
354   GATHER_ND = 107,
355   COS = 108,
356   WHERE = 109,
357   RANK = 110,
358   ELU = 111,
359   REVERSE_SEQUENCE = 112,
360   MATRIX_DIAG = 113,
361   QUANTIZE = 114,
362   MATRIX_SET_DIAG = 115,
363   ROUND = 116,
364   HARD_SWISH = 117,
365   IF = 118,
366   WHILE = 119,
367   NON_MAX_SUPPRESSION_V4 = 120,
368   NON_MAX_SUPPRESSION_V5 = 121,
369   SCATTER_ND = 122,
370   SELECT_V2 = 123,
371   DENSIFY = 124,
372   SEGMENT_SUM = 125,
373   BATCH_MATMUL = 126,
374   PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
375   CUMSUM = 128,
376   CALL_ONCE = 129,
377   BROADCAST_TO = 130,
378   RFFT2D = 131,
379   CONV_3D = 132,
380   IMAG=133,
381   REAL=134,
382   COMPLEX_ABS=135,
383   HASHTABLE = 136,
384   HASHTABLE_FIND = 137,
385   HASHTABLE_IMPORT = 138,
386   HASHTABLE_SIZE = 139,
387   REDUCE_ALL = 140,
388   CONV_3D_TRANSPOSE = 141,
389   VAR_HANDLE = 142,
390   READ_VARIABLE = 143,
391   ASSIGN_VARIABLE = 144,
392   BROADCAST_ARGS = 145,
393   RANDOM_STANDARD_NORMAL = 146,
394 }
395 // LINT.ThenChange(nnapi_linter/linter.proto)
396
397 // Options for the builtin operators.
398 union BuiltinOptions {
399   Conv2DOptions,
400   DepthwiseConv2DOptions,
401   ConcatEmbeddingsOptions,
402   LSHProjectionOptions,
403   Pool2DOptions,
404   SVDFOptions,
405   RNNOptions,
406   FullyConnectedOptions,
407   SoftmaxOptions,
408   ConcatenationOptions,
409   AddOptions,
410   L2NormOptions,
411   LocalResponseNormalizationOptions,
412   LSTMOptions,
413   ResizeBilinearOptions,
414   CallOptions,
415   ReshapeOptions,
416   SkipGramOptions,
417   SpaceToDepthOptions,
418   EmbeddingLookupSparseOptions,
419   MulOptions,
420   PadOptions,
421   GatherOptions,
422   BatchToSpaceNDOptions,
423   SpaceToBatchNDOptions,
424   TransposeOptions,
425   ReducerOptions,
426   SubOptions,
427   DivOptions,
428   SqueezeOptions,
429   SequenceRNNOptions,
430   StridedSliceOptions,
431   ExpOptions,
432   TopKV2Options,
433   SplitOptions,
434   LogSoftmaxOptions,
435   CastOptions,
436   DequantizeOptions,
437   MaximumMinimumOptions,
438   ArgMaxOptions,
439   LessOptions,
440   NegOptions,
441   PadV2Options,
442   GreaterOptions,
443   GreaterEqualOptions,
444   LessEqualOptions,
445   SelectOptions,
446   SliceOptions,
447   TransposeConvOptions,
448   SparseToDenseOptions,
449   TileOptions,
450   ExpandDimsOptions,
451   EqualOptions,
452   NotEqualOptions,
453   ShapeOptions,
454   PowOptions,
455   ArgMinOptions,
456   FakeQuantOptions,
457   PackOptions,
458   LogicalOrOptions,
459   OneHotOptions,
460   LogicalAndOptions,
461   LogicalNotOptions,
462   UnpackOptions,
463   FloorDivOptions,
464   SquareOptions,
465   ZerosLikeOptions,
466   FillOptions,
467   BidirectionalSequenceLSTMOptions,
468   BidirectionalSequenceRNNOptions,
469   UnidirectionalSequenceLSTMOptions,
470   FloorModOptions,
471   RangeOptions,
472   ResizeNearestNeighborOptions,
473   LeakyReluOptions,
474   SquaredDifferenceOptions,
475   MirrorPadOptions,
476   AbsOptions,
477   SplitVOptions,
478   UniqueOptions,
479   ReverseV2Options,
480   AddNOptions,
481   GatherNdOptions,
482   CosOptions,
483   WhereOptions,
484   RankOptions,
485   ReverseSequenceOptions,
486   MatrixDiagOptions,
487   QuantizeOptions,
488   MatrixSetDiagOptions,
489   HardSwishOptions,
490   IfOptions,
491   WhileOptions,
492   DepthToSpaceOptions,
493   NonMaxSuppressionV4Options,
494   NonMaxSuppressionV5Options,
495   ScatterNdOptions,
496   SelectV2Options,
497   DensifyOptions,
498   SegmentSumOptions,
499   BatchMatMulOptions,
500   CumsumOptions,
501   CallOnceOptions,
502   BroadcastToOptions,
503   Rfft2dOptions,
504   Conv3DOptions,
505   HashtableOptions,
506   HashtableFindOptions,
507   HashtableImportOptions,
508   HashtableSizeOptions,
509   VarHandleOptions,
510   ReadVariableOptions,
511   AssignVariableOptions,
512   RandomOptions,
513   BCQGatherOptions = 252,
514   BCQFullyConnectedOptions = 253,
515   InstanceNormOptions = 254,
516 }
517
518 enum Padding : byte { SAME, VALID }
519
520 enum ActivationFunctionType : byte {
521   NONE = 0,
522   RELU = 1,
523   RELU_N1_TO_1 = 2,
524   RELU6 = 3,
525   TANH = 4,
526   SIGN_BIT = 5,
527 }
528
529 table Conv2DOptions {
530   padding:Padding;
531   stride_w:int;
532   stride_h:int;
533   fused_activation_function:ActivationFunctionType;
534   dilation_w_factor:int = 1;
535   dilation_h_factor:int = 1;
536 }
537
538 // Options for both Conv3D and Conv3DTranspose.
539 table Conv3DOptions {
540   padding:Padding;
541   stride_d:int;
542   stride_w:int;
543   stride_h:int;
544   fused_activation_function:ActivationFunctionType;
545   dilation_d_factor:int = 1;
546   dilation_w_factor:int = 1;
547   dilation_h_factor:int = 1;
548 }
549
550 table Pool2DOptions {
551   padding:Padding;
552   stride_w:int;
553   stride_h:int;
554   filter_width:int;
555   filter_height:int;
556   fused_activation_function:ActivationFunctionType;
557 }
558
559 table DepthwiseConv2DOptions {
560   // Parameters for DepthwiseConv version 1 or above.
561   padding:Padding;
562   stride_w:int;
563   stride_h:int;
564   // `depth_multiplier` is redundant. It's used by CPU kernels in
565   // TensorFlow 2.0 or below, but ignored in versions above.
566   // See comments in lite/c/builtin_op_data.h for more details.
567   depth_multiplier:int;
568   fused_activation_function:ActivationFunctionType;
569   // Parameters for DepthwiseConv version 2 or above.
570   dilation_w_factor:int = 1;
571   dilation_h_factor:int = 1;
572 }
573
574 table ConcatEmbeddingsOptions {
575   num_channels:int;
576   num_columns_per_channel:[int];
577   embedding_dim_per_channel:[int]; // This could be inferred from parameters.
578 }
579
580 enum LSHProjectionType: byte {
581   UNKNOWN = 0,
582   SPARSE = 1,
583   DENSE = 2,
584 }
585
586 table LSHProjectionOptions {
587   type: LSHProjectionType;
588 }
589
590 table SVDFOptions {
591   rank:int;
592   fused_activation_function:ActivationFunctionType;
593   // For weights-only quantization, use asymmetric quantization for non
594   // constant inputs at evaluation time.
595   asymmetric_quantize_inputs:bool;
596 }
597
598 // An implementation of TensorFlow RNNCell.
599 table RNNOptions {
600   fused_activation_function:ActivationFunctionType;
601   asymmetric_quantize_inputs:bool;
602 }
603
604 // An implementation of TensorFlow dynamic_rnn with RNNCell.
605 table SequenceRNNOptions {
606   time_major:bool;
607   fused_activation_function:ActivationFunctionType;
608   asymmetric_quantize_inputs:bool;
609 }
610
611 // An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell.
612 table BidirectionalSequenceRNNOptions {
613   time_major:bool;
614   fused_activation_function:ActivationFunctionType;
615   merge_outputs: bool;
616   asymmetric_quantize_inputs:bool;
617 }
618
619 enum FullyConnectedOptionsWeightsFormat: byte {
620   DEFAULT = 0,
621   SHUFFLED4x16INT8 = 1,
622   SHUFFLED16x1FLOAT32 = 127
623 }
624
625 // An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
626 table FullyConnectedOptions {
627   // Parameters for FullyConnected version 1 or above.
628   fused_activation_function:ActivationFunctionType;
629
630   // Parameters for FullyConnected version 2 or above.
631   weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
632
633   // Parameters for FullyConnected version 5 or above.
634   // If set to true, then the number of dimension is preserved. Furthermore,
635   // all but the last dimension of the input and output shapes will be equal.
636   keep_num_dims: bool;
637
638   // Parameters for FullyConnected version 7 or above.
639   // If set to true, then weights-only op will use asymmetric quantization for
640   // inputs.
641   asymmetric_quantize_inputs: bool;
642 }
643
644 table SoftmaxOptions {
645   beta: float;
646 }
647
648 // An implementation of TensorFlow concat.
649 table ConcatenationOptions {
650   axis:int;
651   fused_activation_function:ActivationFunctionType;
652 }
653
654 table AddOptions {
655   fused_activation_function:ActivationFunctionType;
656   // Parameters supported by version 3.
657   pot_scale_int16:bool = true;
658 }
659
660 table MulOptions {
661   fused_activation_function:ActivationFunctionType;
662 }
663
664 table L2NormOptions {
665   // This field is currently ignored in the L2 Norm Op.
666   fused_activation_function:ActivationFunctionType;
667 }
668
669 table LocalResponseNormalizationOptions {
670   radius:int;
671   bias:float;
672   alpha:float;
673   beta:float;
674 }
675
676 enum LSTMKernelType : byte {
677   // Full LSTM kernel which supports peephole and projection.
678   FULL = 0,
679   // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
680   BASIC = 1,
681 }
682
683 // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
684 table LSTMOptions {
685   // Parameters for LSTM version 1 or above.
686   fused_activation_function:ActivationFunctionType;
687   cell_clip: float; // Optional, 0.0 means no clipping
688   proj_clip: float; // Optional, 0.0 means no clipping
689
690   // Parameters for LSTM version 2 or above.
691   // Basic kernel is only supported in version 2 or above.
692   kernel_type: LSTMKernelType = FULL;
693
694   // Parameters for LSTM version 4 or above.
695   asymmetric_quantize_inputs: bool;
696 }
697
698 // An implementation of TensorFlow dynamic_rnn with LSTMCell.
699 table UnidirectionalSequenceLSTMOptions {
700   fused_activation_function:ActivationFunctionType;
701   cell_clip: float; // Optional, 0.0 means no clipping
702   proj_clip: float; // Optional, 0.0 means no clipping
703
704   // If true then first dimension is sequence, otherwise batch.
705   time_major:bool;
706
707   // Parameter for Unidirectional Sequence LSTM version 4.
708   asymmetric_quantize_inputs:bool;
709 }
710
711 table BidirectionalSequenceLSTMOptions {
712   // Parameters supported by version 1:
713   fused_activation_function:ActivationFunctionType;
714   cell_clip: float; // Optional, 0.0 means no clipping
715   proj_clip: float; // Optional, 0.0 means no clipping
716
717   // If true, store the outputs of both directions into the first output.
718   merge_outputs: bool;
719
720   // Parameters supported by version 2:
721   // If true then first dimension is sequence, otherwise batch.
722   // Version 1 implementations assumed time_major to be true, so this default
723   // value should never change.
724   time_major: bool = true;
725
726   // Parameters for version 3 or above.
727   asymmetric_quantize_inputs:bool;
728 }
729
730 table ResizeBilinearOptions {
731   new_height: int (deprecated);
732   new_width: int (deprecated);
733   align_corners: bool;
734   half_pixel_centers: bool;
735 }
736
737 table ResizeNearestNeighborOptions {
738   align_corners: bool;
739   half_pixel_centers: bool;
740 }
741
742 // A call operation options
743 table CallOptions {
744   // The subgraph index that needs to be called.
745   subgraph:uint;
746 }
747
748 table PadOptions {
749 }
750
751 table PadV2Options {
752 }
753
754 table ReshapeOptions {
755   new_shape:[int];
756 }
757
758 table SpaceToBatchNDOptions {
759 }
760
761 table BatchToSpaceNDOptions {
762 }
763
764 table SkipGramOptions {
765   ngram_size: int;
766   max_skip_size: int;
767   include_all_ngrams: bool;
768 }
769
770 table SpaceToDepthOptions {
771   block_size: int;
772 }
773
774 table DepthToSpaceOptions {
775   block_size: int;
776 }
777
778 table SubOptions {
779   fused_activation_function:ActivationFunctionType;
780   // Parameters supported by version 5
781   pot_scale_int16:bool = true;
782 }
783
784 table DivOptions {
785   fused_activation_function:ActivationFunctionType;
786 }
787
788 table TopKV2Options {
789 }
790
791 enum CombinerType : byte {
792   SUM = 0,
793   MEAN = 1,
794   SQRTN = 2,
795 }
796
797 table EmbeddingLookupSparseOptions {
798   combiner:CombinerType;
799 }
800
801 table GatherOptions {
802   axis: int;
803   // Parameters for Gather version 5 or above.
804   batch_dims: int = 0;
805 }
806
807 table TransposeOptions {
808 }
809
810 table ExpOptions {
811 }
812
813 table CosOptions {
814 }
815
816 table ReducerOptions {
817   keep_dims: bool;
818 }
819
820 table SqueezeOptions {
821   squeeze_dims:[int];
822 }
823
824 table SplitOptions {
825   num_splits: int;
826 }
827
828 table SplitVOptions {
829   num_splits: int;
830 }
831
832 table StridedSliceOptions {
833   begin_mask: int;
834   end_mask: int;
835   ellipsis_mask: int;
836   new_axis_mask: int;
837   shrink_axis_mask: int;
838 }
839
840 table LogSoftmaxOptions {
841 }
842
843 table CastOptions {
844   in_data_type: TensorType;
845   out_data_type: TensorType;
846 }
847
848 table DequantizeOptions {
849 }
850
851 table MaximumMinimumOptions {
852 }
853
854 table TileOptions {
855 }
856
857 table ArgMaxOptions {
858   output_type : TensorType;
859 }
860
861 table ArgMinOptions {
862   output_type : TensorType;
863 }
864
865 table GreaterOptions {
866 }
867
868 table GreaterEqualOptions {
869 }
870
871 table LessOptions {
872 }
873
874 table LessEqualOptions {
875 }
876
877 table NegOptions {
878 }
879
880 table SelectOptions {
881 }
882
883 table SliceOptions {
884 }
885
886 table TransposeConvOptions {
887   padding:Padding;
888   stride_w:int;
889   stride_h:int;
890 }
891
892 table ExpandDimsOptions {
893 }
894
895 table SparseToDenseOptions {
896   validate_indices:bool;
897 }
898
899 table EqualOptions {
900 }
901
902 table NotEqualOptions {
903 }
904
905 table ShapeOptions {
906   // Optional output type of the operation (int32 or int64). Defaults to int32.
907   out_type : TensorType;
908 }
909
910 table RankOptions {
911 }
912
913 table PowOptions {
914 }
915
916 table FakeQuantOptions {
917   // Parameters supported by version 1:
918   min:float;
919   max:float;
920   num_bits:int;
921
922   // Parameters supported by version 2:
923   narrow_range:bool;
924 }
925
926 table PackOptions {
927   values_count:int;
928   axis:int;
929 }
930
931 table LogicalOrOptions {
932 }
933
934 table OneHotOptions {
935   axis:int;
936 }
937
938 table AbsOptions {
939 }
940
941
942 table HardSwishOptions {
943 }
944
945 table LogicalAndOptions {
946 }
947
948 table LogicalNotOptions {
949 }
950
951 table UnpackOptions {
952   num:int;
953   axis:int;
954 }
955
956 table FloorDivOptions {
957 }
958
959 table SquareOptions {
960 }
961
962 table ZerosLikeOptions {
963 }
964
965 table FillOptions {
966 }
967
968 table FloorModOptions {
969 }
970
971 table RangeOptions {
972 }
973
974 table LeakyReluOptions {
975   alpha:float;
976 }
977
978 table SquaredDifferenceOptions {
979 }
980
981 enum MirrorPadMode : byte {
982   // Doesn't include borders.
983   REFLECT = 0,
984   // Includes borders.
985   SYMMETRIC = 1,
986 }
987
988 table MirrorPadOptions {
989   mode:MirrorPadMode;
990 }
991
992 table UniqueOptions {
993   idx_out_type:TensorType = INT32;
994 }
995
996 table ReverseV2Options {
997 }
998
999 table AddNOptions {
1000 }
1001
1002 table GatherNdOptions {
1003 }
1004
1005 table WhereOptions {
1006 }
1007
1008 table ReverseSequenceOptions {
1009   seq_dim:int;
1010   batch_dim:int = 0;
1011 }
1012
1013 table MatrixDiagOptions {
1014 }
1015
1016 table QuantizeOptions {
1017 }
1018
1019 table MatrixSetDiagOptions {
1020 }
1021
1022 table IfOptions {
1023   then_subgraph_index:int;
1024   else_subgraph_index:int;
1025 }
1026
1027 table CallOnceOptions {
1028   init_subgraph_index:int;
1029 }
1030
1031 table WhileOptions {
1032   cond_subgraph_index:int;
1033   body_subgraph_index:int;
1034 }
1035
1036 table NonMaxSuppressionV4Options {
1037 }
1038
1039 table NonMaxSuppressionV5Options {
1040 }
1041
1042 table ScatterNdOptions {
1043 }
1044
1045 table SelectV2Options {
1046 }
1047
1048 table DensifyOptions {
1049 }
1050
1051 table SegmentSumOptions {
1052 }
1053
1054 table BatchMatMulOptions {
1055   adjoint_lhs:bool;
1056   adjoint_rhs:bool;
1057   // Parameters for BatchMatMul version 4 or above.
1058   // If set to true, then weights-only op will use asymmetric quantization for
1059   // inputs.
1060   asymmetric_quantize_inputs: bool;
1061 }
1062
1063 table CumsumOptions {
1064   exclusive:bool;
1065   reverse:bool;
1066 }
1067
1068 table BroadcastToOptions {
1069 }
1070
1071 table Rfft2dOptions {
1072 }
1073
1074 table HashtableOptions {
1075   // The identity of hash tables. This identity will be used across different
1076   // subgraphs in the same interpreter instance.
1077   table_id:int;
1078   key_dtype:TensorType;
1079   value_dtype:TensorType;
1080 }
1081
1082 table HashtableFindOptions {
1083 }
1084
1085 table HashtableImportOptions {
1086 }
1087
1088 table HashtableSizeOptions {
1089 }
1090
1091 table VarHandleOptions {
1092   container:string;
1093   shared_name:string;
1094 }
1095
1096 table ReadVariableOptions {
1097 }
1098
1099 table AssignVariableOptions {
1100 }
1101
1102 table RandomOptions {
1103   seed: int;
1104   seed2: int;
1105 }
1106
1107 table BCQGatherOptions {
1108   input_hidden_size: int;
1109   axis: int;
1110 }
1111
1112 table BCQFullyConnectedOptions {
1113   weights_hidden_size: int;
1114   fused_activation_function:ActivationFunctionType;
1115 }
1116
1117 table InstanceNormOptions {
1118   epsilon:float;
1119   fused_activation_function:ActivationFunctionType;
1120 }
1121
1122 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
1123 // builtin, or a string if the operator is custom.
1124 table OperatorCode {
1125   // This field is for backward compatibility. This field will be used when
1126   // the value of the extended builtin_code field has less than
1127   // BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
1128   deprecated_builtin_code:byte;
1129   custom_code:string;
1130
1131   // The version of the operator. The version need to be bumped whenever new
1132   // parameters are introduced into an op.
1133   version:int = 1;
1134
1135   // This field is introduced for resolving op builtin code shortage problem
1136   // (the original BuiltinOperator enum field was represented as a byte).
1137   // This field will be used when the value of the extended builtin_code field
1138   // has greater than BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
1139   builtin_code:BuiltinOperator;
1140 }
1141
1142 enum CustomOptionsFormat : byte {
1143   FLEXBUFFERS = 0,
1144 }
1145
1146 enum DataFormat : byte {
1147   // For 2D data, NHWC(batch, height, width, channels)
1148   // For 3D data, NDHWC(batch, depth, height, width, channels)
1149   CHANNELS_LAST = 0,
1150   // For 2D data, NCHW(batch, channels, height, width)
1151   // For 3D data, NCDHW(batch, channels, depth, height, width)
1152   CHANNELS_FIRST = 1,
1153 }
1154
1155 // An operator takes tensors as inputs and outputs. The type of operation being
1156 // performed is determined by an index into the list of valid OperatorCodes,
1157 // while the specifics of each operations is configured using builtin_options
1158 // or custom_options.
1159 table Operator {
1160   // Index into the operator_codes array. Using an integer here avoids
1161   // complicate map lookups.
1162   opcode_index:uint;
1163
1164   // Optional input are indicated by -1.
1165   inputs:[int];
1166   outputs:[int];
1167
1168   builtin_options:BuiltinOptions;
1169   custom_options:[ubyte];
1170   custom_options_format:CustomOptionsFormat;
1171
1172   // A list of booleans indicating the input tensors which are being mutated by
1173   // this operator.(e.g. used by RNN and LSTM).
1174   // For example, if the "inputs" array refers to 5 tensors and the second and
1175   // fifth are mutable variables, then this list will contain
1176   // [false, true, false, false, true].
1177   //
1178   // If the list is empty, no variable is mutated in this operator.
1179   // The list either has the same length as `inputs`, or is empty.
1180   mutating_variable_inputs:[bool];
1181
1182   // A list of indices to the subgraph's "tensors" that are internal to an Op.
1183   // Internal tensors are those that do not flow in or out of the operation,
1184   // but instead are part of internal computation. As such, the operation's
1185   // implementation may manage its memory more efficiently. They are needed
1186   // however (i.e. not just an implementation detail) since they are part of the
1187   // computation, which may require relevant metadata such as quantization
1188   // parameters.
1189   intermediates:[int];
1190 }
1191
1192 // The root type, defining a subgraph, which typically represents an entire
1193 // model.
1194 table SubGraph {
1195   // A list of all tensors used in this subgraph.
1196   tensors:[Tensor];
1197
1198   // Indices of the tensors that are inputs into this subgraph. Note this is
1199   // the list of non-static tensors that feed into the subgraph for inference.
1200   inputs:[int];
1201
1202   // Indices of the tensors that are outputs out of this subgraph. Note this is
1203   // the list of output tensors that are considered the product of the
1204   // subgraph's inference.
1205   outputs:[int];
1206
1207   // All operators, in execution order.
1208   operators:[Operator];
1209
1210   // Name of this subgraph (used for debugging).
1211   name:string;
1212
1213   // Data format for input/output of SubGraph
1214   data_format: DataFormat;
1215 }
1216
1217 // Table of raw data buffers (used for constant tensors). Referenced by tensors
1218 // by index. The generous alignment accommodates mmap-friendly data structures.
1219 table Buffer {
1220   data:[ubyte] (force_align: 16);
1221 }
1222
1223 table Metadata {
1224   // A human readable string to uniquely identify a Metadata.
1225   name:string;
1226   // An index to the buffers table.
1227   buffer:uint;
1228 }
1229
1230 // Map from an alias name of tensor to tensor index in the graph.
1231 // This is used in Signature def.
1232 table TensorMap {
1233   // Represents the alias to use for this tensor.
1234   name:string;
1235
1236   // The actual tensor index in the primary graph, that 'name' corresponds to.
1237   tensor_index:uint;
1238 }
1239
1240 // This corresponds to SignatureDef in Tensorflow SavedModel.
1241 // The SignatureDef will be part of the SavedModel provided for conversion.
1242 table SignatureDef {
1243   // Named inputs for this signature.
1244   inputs:[TensorMap];
1245
1246   // Named outputs for this signature.
1247   outputs:[TensorMap];
1248
1249   // Key value which was in the Tensorflow SavedModel SignatureDef map.
1250   signature_key:string;
1251
1252   // Model tag, deprecated.
1253   deprecated_tag:string (deprecated);
1254
1255   // Index of subgraphs that corresponds to the exported method.
1256   subgraph_index:uint;
1257 }
1258
1259 table Model {
1260   // Version of the schema.
1261   version:uint;
1262
1263   // A list of all operator codes used in this model. This is
1264   // kept in order because operators carry an index into this
1265   // vector.
1266   operator_codes:[OperatorCode];
1267
1268   // All the subgraphs of the model. The 0th is assumed to be the main
1269   // model.
1270   subgraphs:[SubGraph];
1271
1272   // A description of the model.
1273   description:string;
1274
1275   // Buffers of the model.
1276   // Note the 0th entry of this array must be an empty buffer (sentinel).
1277   // This is a convention so that tensors without a buffer can provide 0 as
1278   // their buffer.
1279   buffers:[Buffer];
1280
1281   // Metadata about the model. Indirects into the existings buffers list.
1282   // Deprecated, prefer to use metadata field.
1283   metadata_buffer:[int];
1284
1285   // Metadata about the model.
1286   metadata:[Metadata];
1287
1288   // Optional SignatureDefs for the model.
1289   signature_defs:[SignatureDef];
1290 }
1291
1292 root_type Model;