const uint32_t *dims = getInputDims (idx);
uint32_t tensor_size = getInputElemSize (idx);
+ uint32_t depth = dims[3];
+ bool image;
for (uint32_t rank_idx = 0; rank_idx < MAX_RANK; rank_idx++)
tensor_size *= dims[rank_idx];
+ image = (depth == 1 || depth == 3);
+
/** special handling for TRIV2 */
- if (layout == DATA_LAYOUT_TRIV2 && dims[3] != 3 &&
- dims[3] % DATA_GRANULARITY != 0) {
- uint32_t depth;
+ if (layout == DATA_LAYOUT_TRIV2 && depth % DATA_GRANULARITY != 0 && !image) {
+ uint32_t new_depth;
- depth = (dims[3] + DATA_GRANULARITY - 1) >> DATA_GRANULARITY_SHIFT;
- depth = depth * DATA_GRANULARITY;
+ new_depth = (depth + DATA_GRANULARITY - 1) >> DATA_GRANULARITY_SHIFT;
+ new_depth = new_depth * DATA_GRANULARITY;
- tensor_size /= dims[3];
- tensor_size *= depth;
+ tensor_size /= depth;
+ tensor_size *= new_depth;
}
return tensor_size;
const uint32_t *dims = getOutputDims (idx);
uint32_t tensor_size = getOutputElemSize (idx);
+ uint32_t depth = dims[3];
for (uint32_t rank_idx = 0; rank_idx < MAX_RANK; rank_idx++)
tensor_size *= dims[rank_idx];
/** special handling for TRIV2 */
- if (layout == DATA_LAYOUT_TRIV2 && dims[3] != 3 &&
- dims[3] % DATA_GRANULARITY != 0) {
- uint32_t depth;
+ if (layout == DATA_LAYOUT_TRIV2 && depth % DATA_GRANULARITY != 0) {
+ uint32_t new_depth;
- depth = (dims[3] + DATA_GRANULARITY - 1) >> DATA_GRANULARITY_SHIFT;
- depth = depth * DATA_GRANULARITY;
+ new_depth = (depth + DATA_GRANULARITY - 1) >> DATA_GRANULARITY_SHIFT;
+ new_depth = new_depth * DATA_GRANULARITY;
- tensor_size /= dims[3];
- tensor_size *= depth;
+ tensor_size /= depth;
+ tensor_size *= new_depth;
}
return tensor_size;