92e1614d194d94fad2f6a3ce23738c582798046e
[platform/core/ml/nnfw.git] / compute / cker / include / cker / eigen / eigen_spatial_convolutions-inl.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
19 #define __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
20
21 #include "cker/eigen/eigen_convolution_helpers.h"
22
23 // Note this header is used in both TF and TFLite.
24 namespace Eigen
25 {
26
27 namespace internal
28 {
29
30 // WARNING: Most of the code here implicitly assumes that the matrix is in
31 // ColMajor layout. This is guaranteed by the tensor contraction (see
32 // TensorContraction.h).
33 //
34 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
35 // We don't want to actually extract image patches and reshape the result into
36 // a matrix (this involves allocating huge extra memory), so the patch
37 // extraction and reshape operations are implicit.
38 //
39 // TensorContractionInputMapper takes a matrix index and returns the coefficient
40 // (or the packet) of the "virtual tensor", that would be at that index if we
41 // were to actually reshape the result of patch extraction.
42 //
43 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
44 // at the given vertical and horizontal offsets.
45 //
46 // "Virtual matrix" dimensions:
47 //   *0: kernelChannels * kernelRows * kernelCols;
48 //    1: out_height * out_width; * OTHERS (e.g batches, etc...)
49 //
50 // *) extracted patches are continuous in memory (innermost dimension assuming
51 //    col major layout)
52 //
53 // With this dimensions:
54 //   row - offset within a single patch (in code: patchId)
55 //   col - index of the extracted patch (in code: patchIndex)
56 //         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
57 //
58 // TODO(ezhulenev): Consolidate this part of the code with the image patch
59 // extraction code since they are both very similar.
60
61 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
62           typename Scalar_, typename Index, typename nocontract_t, typename contract_t, int Side,
63           int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
64 class TensorContractionInputMapper<
65     Scalar_, Index, Side,
66     TensorEvaluator<
67         const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
68         Device>,
69     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
70 {
71 public:
72   typedef Scalar_ Scalar;
73
74   typedef TensorContractionInputMapper<
75       Scalar, Index, Side,
76       TensorEvaluator<
77           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
78           Device>,
79       nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
80       Self;
81
82   typedef TensorContractionSubMapper<
83       Scalar, Index, Side,
84       TensorEvaluator<
85           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
86           Device>,
87       nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
88       SubMapper;
89
90   typedef SubMapper VectorMapper;
91   typedef SubMapper LinearMapper;
92   typedef typename packet_traits<Scalar>::type Packet;
93
94   typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
95
96   EIGEN_DEVICE_FUNC
97   TensorContractionInputMapper(
98       const TensorEvaluator<
99           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
100           Device> &tensor,
101       const nocontract_t &, const nocontract_t &, const contract_t &, const contract_t &)
102       : m_impl(tensor.impl().impl())
103   {
104     Index patch_rows;
105     Index patch_depth;
106     if (internal::traits<ArgType>::Layout == ColMajor)
107     {
108       patch_depth = tensor.impl().dimensions()[0];
109       patch_rows = tensor.impl().dimensions()[1];
110       m_patch_cols = tensor.impl().dimensions()[2];
111       m_num_patches = tensor.impl().dimensions()[3];
112     }
113     else
114     {
115       const size_t NumDims = tensor.impl().dimensions().size();
116       patch_depth = tensor.impl().dimensions()[NumDims - 1];
117       patch_rows = tensor.impl().dimensions()[NumDims - 2];
118       m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
119       m_num_patches = tensor.impl().dimensions()[NumDims - 4];
120     }
121
122     // Strides for navigating through the single patch.
123     m_patch_row_stride = patch_depth;
124     m_patch_col_stride = patch_rows * m_patch_row_stride;
125
126     m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
127     m_patch_col_inflate_strides = tensor.impl().colInflateStride();
128
129     m_colStride = patch_rows;
130
131     m_outputRows = tensor.impl().outputRows();
132     m_outputCols = tensor.impl().outputCols();
133     m_row_strides = tensor.impl().userRowStride();
134     m_col_strides = tensor.impl().userColStride();
135
136     m_in_row_strides = tensor.impl().userInRowStride();
137     m_in_col_strides = tensor.impl().userInColStride();
138
139     if (internal::traits<ArgType>::Layout == ColMajor)
140     {
141       m_inputRows = tensor.impl().impl().dimensions()[1];
142       m_inputCols = tensor.impl().impl().dimensions()[2];
143     }
144     else
145     {
146       const int NumDims = tensor.impl().impl().dimensions().size();
147       m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
148       m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
149     }
150
151     m_rowInputStride = patch_depth;
152     m_colInputStride = patch_depth * m_inputRows;
153     m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
154
155     m_rowPaddingTop = tensor.impl().rowPaddingTop();
156     m_colPaddingLeft = tensor.impl().colPaddingLeft();
157
158     m_fastPatchRowStride = internal::TensorIntDivisor<Index>(m_patch_row_stride);
159     m_fastPatchColStride = internal::TensorIntDivisor<Index>(m_patch_col_stride);
160     m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
161     m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
162     m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
163     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
164     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
165     m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
166   }
167
168   EIGEN_DEVICE_FUNC
169   TensorContractionInputMapper(const TensorContractionInputMapper &base_mapper)
170       : m_impl(base_mapper.m_impl)
171   {
172     m_patch_cols = base_mapper.m_patch_cols;
173     m_num_patches = base_mapper.m_num_patches;
174
175     m_patch_row_stride = base_mapper.m_patch_row_stride;
176     m_patch_col_stride = base_mapper.m_patch_col_stride;
177
178     m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
179     m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
180
181     m_colStride = base_mapper.m_colStride;
182
183     m_rowInputStride = base_mapper.m_rowInputStride;
184     m_colInputStride = base_mapper.m_colInputStride;
185     m_patchInputStride = base_mapper.m_patchInputStride;
186
187     m_inputRows = base_mapper.m_inputRows;
188     m_inputCols = base_mapper.m_inputCols;
189
190     m_outputRows = base_mapper.m_outputRows;
191     m_outputCols = base_mapper.m_outputCols;
192     m_row_strides = base_mapper.m_row_strides;
193     m_col_strides = base_mapper.m_col_strides;
194
195     m_in_row_strides = base_mapper.m_in_row_strides;
196     m_in_col_strides = base_mapper.m_in_col_strides;
197
198     m_rowPaddingTop = base_mapper.m_rowPaddingTop;
199     m_colPaddingLeft = base_mapper.m_colPaddingLeft;
200
201     m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
202     m_fastPatchColStride = base_mapper.m_fastPatchColStride;
203     m_fastInputRowStride = base_mapper.m_fastInputRowStride;
204     m_fastInputColStride = base_mapper.m_fastInputColStride;
205     m_fastNumPatches = base_mapper.m_fastNumPatches;
206     m_fastColStride = base_mapper.m_fastColStride;
207     m_fastOutputRows = base_mapper.m_fastOutputRows;
208     m_fastDimZero = base_mapper.m_fastDimZero;
209   }
210
211   // If true, turns off some optimizations for loading packets since the image
212   // patches are "non-standard" such as there are non-trivial strides or
213   // inflations in the input.
214   EIGEN_DEVICE_FUNC
215   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const
216   {
217     return m_in_row_strides != 1 || m_in_col_strides != 1 || m_patch_row_inflate_strides != 1 ||
218            m_patch_col_inflate_strides != 1;
219   }
220
221   EIGEN_DEVICE_FUNC
222   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const
223   {
224     return SubMapper(*this, i, j);
225   }
226
227   EIGEN_DEVICE_FUNC
228   EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const
229   {
230     return LinearMapper(*this, i, j);
231   }
232
233   EIGEN_DEVICE_FUNC
234   EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const
235   {
236     Index rowIndex, colIndex, otherIndex;
237     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
238     return loadCoeff(row, rowIndex, colIndex, otherIndex);
239   }
240
241   // Load the coefficient at the patchIndex location instead of the usual
242   // m_rowIndex,
243   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
244   // EIGEN_DEVICE_FUNC
245   EIGEN_DEVICE_FUNC
246   EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const
247   {
248     Index rowIndex, colIndex, otherIndex;
249     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
250     return loadCoeff(row, rowIndex, colIndex, otherIndex);
251   }
252
253   EIGEN_DEVICE_FUNC
254   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const
255   {
256     Index rowIndex, colIndex, otherIndex;
257     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
258     return loadPacket(row, rowIndex, colIndex, otherIndex);
259   }
260
261   // Load the packet at the patchIndex location instead of the usual m_rowIndex,
262   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
263   EIGEN_DEVICE_FUNC
264   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const
265   {
266     Index rowIndex, colIndex, otherIndex;
267     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
268     return loadPacket(row, rowIndex, colIndex, otherIndex);
269   }
270
271   EIGEN_DEVICE_FUNC
272   EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device> &impl() const { return m_impl; }
273
274   EIGEN_DEVICE_FUNC
275   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
276   EIGEN_DEVICE_FUNC
277   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
278   EIGEN_DEVICE_FUNC
279   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
280
281 private:
282   friend class TensorContractionSubMapper<
283       Scalar, Index, Side,
284       TensorEvaluator<
285           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
286           Device>,
287       nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>;
288
289   // Load coefficient from a patch specified by the "within patch offset"
290   // (patchId) and the precomputed indices of the first element of the patch.
291   EIGEN_DEVICE_FUNC
292   EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex,
293                                        Index otherIndex) const
294   {
295     // Find the offset of the element wrt the location of the first element.
296     const Index patchOffset = patchId / m_fastDimZero;
297
298     const Index colOffset = patchOffset / m_fastColStride;
299     const Index inputCol = colIndex + colOffset * m_in_col_strides;
300     const Index origInputCol = (m_patch_col_inflate_strides == 1)
301                                    ? inputCol
302                                    : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
303
304     const Index rowOffset = patchOffset - colOffset * m_colStride;
305     const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
306     const Index origInputRow = (m_patch_row_inflate_strides == 1)
307                                    ? inputRow
308                                    : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
309     if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
310         origInputRow >= m_inputRows || (inputCol != origInputCol * m_patch_col_inflate_strides) ||
311         (inputRow != origInputRow * m_patch_row_inflate_strides))
312     {
313       return Scalar(0);
314     }
315     const Index depth = patchId - patchOffset * patchDepth();
316     const Index inputIndex =
317         depth + origInputRow * m_rowInputStride + origInputCol * m_colInputStride + otherIndex;
318     return m_impl.coeff(inputIndex);
319   }
320
321   // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
322   // and `in_strides` equal to 1 (template specialization without templates).
323   EIGEN_DEVICE_FUNC
324   EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex,
325                                                Index otherIndex) const
326   {
327     eigen_assert(!nonStandardPatches());
328
329     // Find the offset of the element wrt the location of the first element.
330     const Index patchOffset = patchId / m_fastDimZero;
331     const Index colOffset = patchOffset / m_fastColStride;
332     const Index rowOffset = patchOffset - colOffset * m_colStride;
333     const Index inputCol = colIndex + colOffset;
334     const Index inputRow = rowIndex + rowOffset;
335     if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows)
336     {
337       return Scalar(0);
338     }
339     const Index depth = patchId - patchOffset * patchDepth();
340     const Index inputIndex =
341         depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
342     return m_impl.coeff(inputIndex);
343   }
344
345   // Load packet from a patch specified by the "within patch offset"
346   // (patchId) and the precomputed indices of the first element of the patch.
347   EIGEN_DEVICE_FUNC
348   EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex,
349                                         Index otherIndex) const
350   {
351     const Index packetSize = internal::unpacket_traits<Packet>::size;
352     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
353     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
354
355     if (nonStandardPatches())
356     {
357       return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
358     }
359     typedef decltype(m_impl) TensorEvaluatorT;
360     return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex, colIndex, otherIndex);
361   }
362
363   // Helper function to load a 'partial' packet - this is the single column
364   // part of a packet that is split across two columns. In the 'partial' packet,
365   // the elements corresponding to the column (specified through colOffset) are
366   // loaded and the rest of the elements are zero-filled into the 'partial'
367   // packet. This function is called from loadPacketStandardFromTwoColumns().
368   // This code path is exercised only when the packet type supports masked load
369   // and when the partial packet load is available in the TensorEvaluator.
370   EIGEN_DEVICE_FUNC
371   EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(Index rowIndex, Index colIndex,
372                                                        Index otherIndex, Index patchId,
373                                                        const Index span[],
374                                                        const Index patchOffsets[],
375                                                        Index colOffset) const
376   {
377     const Index inputCol = colIndex + colOffset;
378     const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
379                                  patchOffsets[1] - colOffset * m_colStride};
380     const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
381
382     if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || inputCol >= m_inputCols || inputCol < 0)
383     {
384       // Partial packet is all zeros
385       return internal::pset1<Packet>(Scalar(0));
386     }
387     else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
388     {
389       // From inputIndex-span[0], we need to load elements starting from index
390       // span[0] all the way upto (and including) span[1].
391       const Index depth = patchId - patchOffsets[0] * patchDepth();
392       const Index inputIndex =
393           depth + inputRows[0] * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
394       return m_impl.template partialPacket<Packet>(inputIndex - span[0],
395                                                    mask<Packet>(span[0], span[1] + 1));
396     }
397     else
398     {
399       // Using slow path for this partial packet.
400       // We need to load elements starting from index span[0] all the way upto
401       // (and including) span[1]. We split this load into 3 parts:
402       // 0 : span[0]-1 - Zeros will be loaded for these indices
403       // span[0] : span[1] - Elements will be loaded here for these indices
404       // span[1]+1 : packetSize-1 - Zeross will be loaded for these indices
405       const Index packetSize = internal::unpacket_traits<Packet>::size;
406       EIGEN_ALIGN_MAX
407       typename internal::remove_const<Scalar>::type values[packetSize];
408       for (int i = 0; i < span[0]; ++i)
409         values[i] = Scalar(0);
410       for (int i = span[0]; i < span[1] + 1; ++i)
411         values[i] = loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
412       for (int i = span[1] + 1; i < packetSize; ++i)
413         values[i] = Scalar(0);
414       return internal::pload<Packet>(values);
415     }
416   }
417
418   // Helper function to load a packet that is split across two columns.
419   // If required, this function is called from loadPacketStandard() when the
420   // packet type supports masked load and when the partial packet load is
421   // available in the TensorEvaluator.
422   EIGEN_DEVICE_FUNC
423   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(Index patchId, Index rowIndex,
424                                                               Index colIndex, Index otherIndex,
425                                                               const Index patchOffsets[],
426                                                               const Index colOffsets[]) const
427   {
428     eigen_assert(colOffsets[1] == colOffsets[0] + 1);
429     const Index packetSize = internal::unpacket_traits<Packet>::size;
430
431     // Packet to load will be split into 2 parts where each part spans a single
432     // column. First determine where to split.
433     const Index patchIdSplit = ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
434     const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
435
436     // patchIds[i]:          patchId corresponding to partial packet i
437     // spans[i]:             Start and end indices corresponding to the elements
438     //                       to be loaded for partial packet i
439     // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
440     const Index patchIds[2] = {patchId, patchIdSplit + 1};
441     const Index spans[2][2] = {{0, patchIdSplit - patchId},
442                                {patchIdSplit - patchId + 1, packetSize - 1}};
443     const Index patchOffsets2Cols[2][2] = {{patchOffsets[0], patchOffsetSplit},
444                                            {patchOffsetSplit + 1, patchOffsets[1]}};
445
446     // Load partial packets and do bit-wise OR to generate required packet
447     return internal::por<Packet>(
448         loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], spans[0],
449                                   patchOffsets2Cols[0], colOffsets[0]),
450         loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], spans[1],
451                                   patchOffsets2Cols[1], colOffsets[1]));
452   }
453
454   // Helper function to load a packet that is present in a single columns.
455   // If required, this function is called from loadPacketStandard().
456   EIGEN_DEVICE_FUNC
457   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(Index patchId, Index rowIndex,
458                                                                 Index colIndex, Index otherIndex,
459                                                                 const Index patchOffsets[],
460                                                                 const Index colOffsets[],
461                                                                 const Index inputCols[]) const
462   {
463     eigen_assert(colOffsets[0] == colOffsets[1]);
464     const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
465                                  patchOffsets[1] - colOffsets[1] * m_colStride};
466     eigen_assert(rowOffsets[0] <= rowOffsets[1]);
467     const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
468
469     if (inputRows[0] >= m_inputRows || inputRows[1] < 0)
470     {
471       // all zeros
472       return internal::pset1<Packet>(Scalar(0)); // all zeros
473     }
474
475     if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
476     {
477       // no padding
478       const Index depth = patchId - patchOffsets[0] * patchDepth();
479       const Index inputIndex =
480           depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
481       return m_impl.template packet<Unaligned>(inputIndex);
482     }
483     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
484   }
485
486   // Load standard packet from a patch specified by the "within patch offset"
487   // (patchId) and the precomputed indices of the first element of the patch.
488   // This function will be called if partial packet loading is not available
489   // for the TensorEvaluator or if the packet type does not support masked
490   // load.
491   template <typename PacketT, typename TensorEvaluatorT>
492   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
493       !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
494   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
495   {
496     const Index packetSize = internal::unpacket_traits<Packet>::size;
497     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
498     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
499
500     eigen_assert(!nonStandardPatches());
501
502     if ((patchDepth() % packetSize) == 0)
503     {
504       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
505     }
506
507     // Offsets and input calculation here are identical to
508     // loadCoeffStandard(...), but repeated twice.
509     const Index patchOffsets[2] = {patchId / m_fastDimZero,
510                                    (patchId + packetSize - 1) / m_fastDimZero};
511     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
512                                  patchOffsets[1] / m_fastColStride};
513     const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
514
515     if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
516     {
517       // all zeros
518       return internal::pset1<Packet>(Scalar(0));
519     }
520     if (inputCols[0] == inputCols[1])
521     {
522       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
523                                                 patchOffsets, colOffsets, inputCols);
524     }
525     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
526   }
527
528   // Load standard packet from a patch specified by the "within patch offset"
529   // (patchId) and the precomputed indices of the first element of the patch.
530   // This function will be called if partial packet loading is available for
531   // the TensorEvaluator and if the packet type supports masked load.
532   // The only difference between this and the other case is that if the packet
533   // to load is split across two columns, then in this case instead of going to
534   // the slow (element-by-element) load, we load two packets - each containing
535   // elements from one of the columns (rest of the elements of the packets are
536   // zeroes), and then combine these two packets to generate the required
537   // packet. The idea is to enable fast load (if possible) of these 'partial'
538   // packets.
539   template <typename PacketT, typename TensorEvaluatorT>
540   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
541       TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
542   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
543   {
544     const Index packetSize = internal::unpacket_traits<PacketT>::size;
545     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
546     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
547
548     eigen_assert(!nonStandardPatches());
549
550     if ((patchDepth() % packetSize) == 0)
551     {
552       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
553     }
554
555     // Offsets and input calculation here are identical to
556     // loadCoeffStandard(...), but repeated twice.
557     const Index patchOffsets[2] = {patchId / m_fastDimZero,
558                                    (patchId + packetSize - 1) / m_fastDimZero};
559     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
560                                  patchOffsets[1] / m_fastColStride};
561     const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
562
563     if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
564     {
565       // all zeros
566       return internal::pset1<PacketT>(Scalar(0));
567     }
568     if (inputCols[0] == inputCols[1])
569     {
570       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
571                                                 patchOffsets, colOffsets, inputCols);
572     }
573     if (inputCols[1] == inputCols[0] + 1)
574     {
575       return loadPacketStandardFromTwoColumns(patchId, rowIndex, colIndex, otherIndex, patchOffsets,
576                                               colOffsets);
577     }
578     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
579   }
580
581   EIGEN_DEVICE_FUNC
582   EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, Index colIndex,
583                                             Index otherIndex) const
584   {
585     const Index packetSize = internal::unpacket_traits<Packet>::size;
586     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
587     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
588
589     eigen_assert(!nonStandardPatches());
590     eigen_assert((patchDepth() % packetSize) == 0);
591     // Find the offset of the element wrt the location of the first element.
592     const Index patchOffset = patchId / m_fastDimZero;
593     eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
594
595     const Index colOffset = patchOffset / m_fastColStride;
596     const Index rowOffset = patchOffset - colOffset * m_colStride;
597     const Index inputCol = colIndex + colOffset;
598     const Index inputRow = rowIndex + rowOffset;
599     if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || inputRow >= m_inputRows)
600     {
601       // all zeros
602       return internal::pset1<Packet>(Scalar(0));
603     }
604     // no padding
605     const Index depth = patchId - patchOffset * patchDepth();
606     const Index inputIndex =
607         depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
608     return m_impl.template packet<Unaligned>(inputIndex);
609   }
610
611   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex,
612                                                                       Index colIndex,
613                                                                       Index otherIndex) const
614   {
615     const int packetSize = internal::unpacket_traits<Packet>::size;
616     EIGEN_ALIGN_MAX
617     typename internal::remove_const<Scalar>::type values[packetSize];
618     for (int i = 0; i < packetSize; ++i)
619     {
620       values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
621     }
622     Packet rslt = internal::pload<Packet>(values);
623     return rslt;
624   }
625
626   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
627   computeBaseIndices(Index patchIndex, Index &rowIndex, Index &colIndex, Index &otherIndex) const
628   {
629     const size_t NumInputDims =
630         array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
631     otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
632     const Index patch2DIndex =
633         (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
634     otherIndex *= m_patchInputStride;
635     colIndex = patch2DIndex / m_fastOutputRows;
636     rowIndex = patch2DIndex - colIndex * m_outputRows;
637     colIndex = colIndex * m_col_strides - m_colPaddingLeft;
638     rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
639   }
640
641   Index m_patch_cols;  // number of columns in the patch
642   Index m_num_patches; // number of patches to extract.
643
644   // Strides for navigating through the single patch.
645   Index m_patch_row_stride;
646   Index m_patch_col_stride;
647   internal::TensorIntDivisor<Index> m_fastPatchRowStride;
648   internal::TensorIntDivisor<Index> m_fastPatchColStride;
649
650   Index m_patch_row_inflate_strides; // the strides for row inflation in the
651                                      // image patch
652   Index m_patch_col_inflate_strides; // the strides for col inflation in the
653                                      // image patch
654   // Fast representation of inflation strides.
655   internal::TensorIntDivisor<Index> m_fastInputRowStride;
656   internal::TensorIntDivisor<Index> m_fastInputColStride;
657
658   Index m_otherStride;
659   Index m_colStride;
660   internal::TensorIntDivisor<Index> m_fastNumPatches;
661   internal::TensorIntDivisor<Index> m_fastColStride;
662
663   Index m_rowInputStride;   // row stride in the input tensor
664   Index m_colInputStride;   // col stride in the input tensor
665   Index m_patchInputStride; // patch stride in the input tensor
666
667   Index m_inputRows; // Number of rows in the input tensor
668   Index m_inputCols; // Number of cols in the input tensor
669
670   Index m_outputRows; // Number of convolution output rows
671   Index m_outputCols; // Number of convolution output column
672
673   Index m_row_strides; // User specified row stride
674   Index m_col_strides; // User specified col stride
675
676   Index m_in_row_strides; // User specified input row stride
677   Index m_in_col_strides; // User specified input col stride
678
679   Index m_rowPaddingTop;  // Row padding
680   Index m_colPaddingLeft; // Column padding
681
682   internal::TensorIntDivisor<Index> m_fastOutputRows;
683   internal::TensorIntDivisor<Index> m_fastDimZero;
684
685   const TensorEvaluator<ArgType, Device> m_impl;
686 };
687
688 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
689           typename Scalar, typename Index, typename nocontract_t, typename contract_t, int Side,
690           int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
691 class TensorContractionSubMapper<
692     Scalar, Index, Side,
693     TensorEvaluator<
694         const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
695         Device>,
696     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
697 {
698 public:
699   typedef typename packet_traits<Scalar>::type Packet;
700   typedef typename packet_traits<Scalar>::half HalfPacket;
701
702   typedef TensorContractionInputMapper<
703       Scalar, Index, Side,
704       TensorEvaluator<
705           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
706           Device>,
707       nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
708       ParentMapper;
709
710   typedef TensorContractionSubMapper<
711       Scalar, Index, Side,
712       TensorEvaluator<
713           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
714           Device>,
715       nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
716       Self;
717
718   typedef Self LinearMapper;
719
720   typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
721
722   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const ParentMapper &base_mapper,
723                                                                    Index vert_offset,
724                                                                    Index horiz_offset)
725       : m_depth_offset(vert_offset), m_col_offset(horiz_offset), m_base_mapper(base_mapper)
726   {
727     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
728   }
729   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const Self &base_mapper,
730                                                                    Index vert_offset,
731                                                                    Index horiz_offset)
732       : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
733         m_col_offset(horiz_offset + base_mapper.m_col_offset),
734         m_base_mapper(base_mapper.m_base_mapper)
735   {
736     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
737   }
738   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const
739   {
740     return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
741   }
742   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const
743   {
744     return m_base_mapper(i + m_depth_offset, j + m_col_offset);
745   }
746
747   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const
748   {
749     return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
750   }
751   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const
752   {
753     return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, j + m_col_offset);
754   }
755   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i) const
756   {
757     return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, m_colIndex,
758                                            m_otherIndex);
759   }
760
761   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const
762   {
763     return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
764   }
765   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i) const
766   {
767     typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
768     return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
769         i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
770   }
771   template <typename Packet> EIGEN_DEVICE_FUNC bool aligned(Index) const { return false; }
772
773   EIGEN_DEVICE_FUNC
774   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { return m_base_mapper.nonStandardPatches(); }
775
776   // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
777   // index respectively that fits into the peeled_k elements starting at
778   // m_depth_offset.
779
780   EIGEN_DEVICE_FUNC
781   EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const
782   {
783     const Index max_col =
784         (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / fastPatchColStride();
785     return std::min<Index>(1 + max_col, patchCols());
786   }
787
788   EIGEN_DEVICE_FUNC
789   EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, const Index col) const
790   {
791     const Index max_row =
792         (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - col * patchColStride()) /
793         fastPatchRowStride();
794     return std::min<Index>(1 + max_row, patchRows());
795   }
796
797   EIGEN_DEVICE_FUNC
798   EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, Index row) const
799   {
800     const Index max_depth = m_depth_offset + peeled_k - //
801                             col * patchColStride() -    //
802                             row * patchRowStride();
803     return std::min<Index>(max_depth, patchDepth());
804   }
805
806   // MaxDepth uses only the remaining number of elements in the peeled_k.
807   EIGEN_DEVICE_FUNC
808   EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, const Index start_depth) const
809   {
810     return std::min<Index>(start_depth + num_elements, patchDepth());
811   }
812
813   // Every register matters in this code, so sometimes to prevent register
814   // spilling, instead of the variable that you would expect to see, we use
815   // another one, that is guaranteed to have the same value. E.g. patch depth is
816   // always the same as input depth, and it's also the same as input row stride.
817   // Bunch of other parameters have similar relations.
818
819   typedef internal::TensorIntDivisor<Index> IndexDivisor;
820
821   EIGEN_DEVICE_FUNC
822   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_base_mapper.m_rowInputStride; }
823   EIGEN_DEVICE_FUNC
824   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_base_mapper.m_colStride; }
825   EIGEN_DEVICE_FUNC
826   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_base_mapper.m_patch_cols; }
827
828   EIGEN_DEVICE_FUNC
829   EIGEN_ALWAYS_INLINE Index patchRowStride() const
830   {
831     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
832                  "Patch depth must be equal to patch row stride.");
833     return patchDepth();
834   }
835   EIGEN_DEVICE_FUNC
836   EIGEN_ALWAYS_INLINE Index patchColStride() const { return m_base_mapper.m_patch_col_stride; }
837
838   EIGEN_DEVICE_FUNC
839   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const
840   {
841     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
842                  "Patch depth must be equal to patch row stride.");
843     return m_base_mapper.m_fastDimZero; // patch_depth
844   }
845   EIGEN_DEVICE_FUNC
846   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const
847   {
848     return m_base_mapper.m_fastPatchColStride;
849   }
850
851   EIGEN_DEVICE_FUNC
852   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const
853   {
854     const Index inputIndex = depth + baseIndex;
855     return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
856   }
857   EIGEN_DEVICE_FUNC
858   EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, const Index baseIndex) const
859   {
860     const Index inputIndex = depth + baseIndex;
861     return m_base_mapper.m_impl.coeff(inputIndex);
862   }
863   template <typename PacketT = Packet>
864   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
865       TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
866   partialPacketNoPadding(const Index depth, const Index baseIndex, Index num_coeffs) const
867   {
868     const Index inputIndex = depth + baseIndex;
869     return m_base_mapper.m_impl.template partialPacket<PacketT>(inputIndex,
870                                                                 mask<PacketT>(0, num_coeffs));
871   }
872   EIGEN_DEVICE_FUNC
873   EIGEN_ALWAYS_INLINE bool hasPadding() const
874   {
875     // TODO(ezhulenev): It does seems that for inflated filter it's still
876     // possible to guarantee "no padding or skipping" for non-standard packing.
877     if (nonStandardPatches())
878       return true;
879
880     // Non zero padding before.
881     if (m_base_mapper.m_rowPaddingTop > 0)
882       return true;
883     if (m_base_mapper.m_colPaddingLeft > 0)
884       return true;
885
886     // Non zero padding after in rows.
887     const Index last_row = (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
888     if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows)
889       return true;
890
891     // Non zero padding after in cols.
892     const Index last_col = (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
893     if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols)
894       return true;
895
896     return false;
897   }
898   EIGEN_DEVICE_FUNC
899   EIGEN_ALWAYS_INLINE bool padRow(const Index row) const
900   {
901     const Index r = m_rowIndex + row;
902     return r < 0 || r >= m_base_mapper.m_inputRows;
903   }
904   EIGEN_DEVICE_FUNC
905   EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, const Index last_row) const
906   {
907     return m_rowIndex + first_row < 0 || m_rowIndex + last_row >= m_base_mapper.m_inputRows;
908   }
909   EIGEN_DEVICE_FUNC
910   EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row, Index *orig_row) const
911   {
912     eigen_assert(nonStandardPatches());
913
914     const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
915     *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
916                     ? input_row
917                     : ((input_row >= 0) ? (input_row / m_base_mapper.m_fastInputRowStride) : 0);
918
919     return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
920            (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
921   }
922   EIGEN_DEVICE_FUNC
923   EIGEN_ALWAYS_INLINE bool padCol(const Index col) const
924   {
925     const Index c = m_colIndex + col;
926     return c < 0 || c >= m_base_mapper.m_inputCols;
927   }
928   EIGEN_DEVICE_FUNC
929   EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col, Index *orig_col) const
930   {
931     eigen_assert(nonStandardPatches());
932
933     const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
934     *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
935                     ? input_col
936                     : ((input_col >= 0) ? (input_col / m_base_mapper.m_fastInputColStride) : 0);
937
938     return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
939            (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
940   }
941   EIGEN_DEVICE_FUNC
942   EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const
943   {
944     const Index r = m_rowIndex + row;
945     const Index c = m_colIndex + col;
946     return r * m_base_mapper.m_rowInputStride + c * m_base_mapper.m_colInputStride + m_otherIndex;
947   }
948   // Compute a base index when original input row and column were precomputed
949   // using padOrSkipRow and padOrSkipCol. Used only for non standard patches.
950   EIGEN_DEVICE_FUNC
951   EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row, const Index orig_col) const
952   {
953     return orig_row * m_base_mapper.m_rowInputStride + orig_col * m_base_mapper.m_colInputStride +
954            m_otherIndex;
955   }
956
957   EIGEN_DEVICE_FUNC
958   EIGEN_ALWAYS_INLINE Index rowStride() const { return m_base_mapper.m_row_strides; }
959   EIGEN_DEVICE_FUNC
960   EIGEN_ALWAYS_INLINE Index colStride() const { return m_base_mapper.m_col_strides; }
961
962   EIGEN_DEVICE_FUNC
963   EIGEN_ALWAYS_INLINE Index rowOffset() const
964   {
965     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
966     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
967     return patchOffset - colOffset * m_base_mapper.m_colStride;
968   }
969
970   EIGEN_DEVICE_FUNC
971   EIGEN_ALWAYS_INLINE Index colOffset() const
972   {
973     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
974     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
975     return colOffset;
976   }
977
978   EIGEN_DEVICE_FUNC
979   EIGEN_ALWAYS_INLINE Index depthOffset() const { return m_depth_offset % patchDepth(); }
980
981   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const
982   {
983     return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
984   }
985
986 private:
987   Index m_depth_offset; // First row in the input matrix
988   Index m_col_offset;   // First col in the input matrix
989
990   // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
991   // indices for the first element in a patch specified by col_offset
992   // (see computeBaseIndices(...) for details).
993   Index m_rowIndex;
994   Index m_colIndex;
995   Index m_otherIndex;
996
997   const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
998                                     // performs better in benchmarks.
999 };
1000
1001 // Arrange a block of the right input matrix (in our case it's always a "virtual
1002 // matrix" constructed from extracted image patches) in contiguous memory.
1003 //
1004 // Given column major input (A0 beside A1 in memory):
1005 // A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
1006 // A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
1007 // A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
1008 // A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
1009 // A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
1010 // A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
1011 // A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
1012 // A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
1013 // A8 ...
1014 // ...
1015 //
1016 // *) A, B, C, ... - patches extracted from the original input.
1017 // *) A0, A1, A2 ... - values from the same patch at different offsets.
1018 //
1019 // The traversal (packed rhs memory) order (B0 besides A0 in memory):
1020 // A0 B0 C0 D0 A1 B1 C1 D1 ...
1021 // E0 F0 G0 H0 E1 F1 G1 H1 ...
1022 // ...
1023 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1024 //
1025 // This traversal order must be the same as in default gemm_pack_rhs defined in
1026 // GeneralBlockPanelKernel.h.
1027 //
1028 // *) nr - number of registers along the 'n' dimension.
1029 //    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1030 //    Multiplication" paper.
1031 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1032           typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1033           int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1034           int nr>
1035 struct gemm_pack_rhs<
1036     Scalar, Index,
1037     TensorContractionSubMapper<
1038         Scalar, Index, Rhs,
1039         TensorEvaluator<
1040             const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1041             Device>,
1042         nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered,
1043         Alignment>,
1044     nr, ColMajor, false, false>
1045 {
1046   typedef TensorContractionSubMapper<
1047       Scalar, Index, Rhs,
1048       TensorEvaluator<
1049           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1050           Device>,
1051       nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
1052       SubMapper;
1053   typedef SubMapper DataMapper;
1054   typedef typename packet_traits<Scalar>::type Packet;
1055
1056   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1057
1058   EIGEN_DEVICE_FUNC
1059   EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1060                                     Index stride = 0, Index offset = 0) const
1061   {
1062     eigen_assert(stride == 0);
1063     eigen_assert(offset == 0);
1064     (void)stride;
1065     (void)offset;
1066
1067     const Index packet_cols4 = (cols / 4) * 4;
1068     const Index peeled_k = (depth / packet_size) * packet_size;
1069     const bool non_standard_patches = rhs.nonStandardPatches();
1070
1071     for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1072     {
1073       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1074       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1075       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1076       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1077
1078       Index k = 0;
1079       if ((packet_size % 4) == 0 && !non_standard_patches)
1080       {
1081         // FAST PATH:
1082         // Iterate over patch columns and rows, if we know that a single
1083         // packet do not span across multiple rows or columns.
1084         if ((rhs.patchDepth() % packet_size) == 0)
1085         {
1086           const Index start_col = rhs.colOffset();
1087           const Index max_col = rhs.maxCol(peeled_k);
1088
1089           for (Index c = start_col; c < max_col; ++c)
1090           {
1091             eigen_assert(k <= peeled_k);
1092
1093             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1094             const Index max_row = rhs.maxRow(peeled_k, c);
1095
1096             const bool pad_col0 = dm0.padCol(c);
1097             const bool pad_col1 = dm1.padCol(c);
1098             const bool pad_col2 = dm2.padCol(c);
1099             const bool pad_col3 = dm3.padCol(c);
1100
1101             // Check if we can squeeze reads along the `row` and `depth`
1102             // dimensions (two innermost dimensions).
1103             if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&   //
1104                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1105                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1106                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1107                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1108             {
1109               // Compute how many elements we can squeeze read.
1110               const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1111
1112               // Upper bound for the number of elements in the depth dimension
1113               // that we can squeeze read.
1114               const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1115
1116               // Do not overshoot beyond the block size.
1117               const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1118               eigen_assert((max_depth - start_depth) % packet_size == 0);
1119
1120               const Index idx0 = dm0.baseIndex(start_row, c);
1121               const Index idx1 = dm1.baseIndex(start_row, c);
1122               const Index idx2 = dm2.baseIndex(start_row, c);
1123               const Index idx3 = dm3.baseIndex(start_row, c);
1124
1125               for (Index d = start_depth; d < max_depth; d += packet_size)
1126               {
1127                 eigen_assert(k < peeled_k);
1128                 PacketBlock<Packet, 4> kernel;
1129                 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1130                 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1131                 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1132                 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1133                 ptranspose(kernel);
1134                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1135                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1136                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1137                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1138                 block += 4 * packet_size;
1139                 k += packet_size;
1140               }
1141
1142               // Go to the next column.
1143               continue;
1144             }
1145
1146             // If we can't squeeze reads, process rows one by one.
1147             for (Index r = start_row; r < max_row; ++r)
1148             {
1149               eigen_assert(k <= peeled_k);
1150
1151               const bool pad0 = pad_col0 || dm0.padRow(r);
1152               const bool pad1 = pad_col1 || dm1.padRow(r);
1153               const bool pad2 = pad_col2 || dm2.padRow(r);
1154               const bool pad3 = pad_col3 || dm3.padRow(r);
1155
1156               const Index idx0 = dm0.baseIndex(r, c);
1157               const Index idx1 = dm1.baseIndex(r, c);
1158               const Index idx2 = dm2.baseIndex(r, c);
1159               const Index idx3 = dm3.baseIndex(r, c);
1160
1161               const Index start_depth =
1162                   ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1163               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1164               eigen_assert((max_depth - start_depth) % packet_size == 0);
1165
1166               for (Index d = start_depth; d < max_depth; d += packet_size)
1167               {
1168                 eigen_assert(k < peeled_k);
1169                 PacketBlock<Packet, 4> kernel;
1170                 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1171                 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1172                 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1173                 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1174                 ptranspose(kernel);
1175                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1176                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1177                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1178                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1179                 block += 4 * packet_size;
1180                 k += packet_size;
1181               }
1182             }
1183           }
1184
1185           // The loop above should fill peeled_k elements.
1186           eigen_assert(peeled_k == k);
1187         }
1188         else
1189         {
1190           for (; k < peeled_k; k += packet_size)
1191           {
1192             PacketBlock<Packet, 4> kernel;
1193             kernel.packet[0] = dm0.loadPacketStandard(k);
1194             kernel.packet[1] = dm1.loadPacketStandard(k);
1195             kernel.packet[2] = dm2.loadPacketStandard(k);
1196             kernel.packet[3] = dm3.loadPacketStandard(k);
1197             ptranspose(kernel);
1198             pstoreu(block + 0 * packet_size, kernel.packet[0]);
1199             pstoreu(block + 1 * packet_size, kernel.packet[1]);
1200             pstoreu(block + 2 * packet_size, kernel.packet[2]);
1201             pstoreu(block + 3 * packet_size, kernel.packet[3]);
1202             block += 4 * packet_size;
1203           }
1204         }
1205       }
1206
1207       // Copy the remaining coefficients of the column block after the peeled_k.
1208       if (!rhs.nonStandardPatches())
1209       {
1210         for (; k < depth; k++)
1211         {
1212           block[0] = dm0.loadCoeffStandard(k);
1213           block[1] = dm1.loadCoeffStandard(k);
1214           block[2] = dm2.loadCoeffStandard(k);
1215           block[3] = dm3.loadCoeffStandard(k);
1216           block += 4;
1217         }
1218       }
1219       else
1220       {
1221         for (; k < depth; k++)
1222         {
1223           block[0] = dm0(k);
1224           block[1] = dm1(k);
1225           block[2] = dm2(k);
1226           block[3] = dm3(k);
1227           block += 4;
1228         }
1229       }
1230     }
1231
1232     // copy the remaining columns one at a time (nr==1)
1233     for (Index j2 = packet_cols4; j2 < cols; ++j2)
1234     {
1235       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1236       for (Index k = 0; k < depth; k++)
1237       {
1238         *block = dm0(k);
1239         block += 1;
1240       }
1241     }
1242   }
1243 };
1244
1245 // Template specialization for packet_size = 2. We must special-case packet
1246 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1247 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1248           typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1249           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1250 struct gemm_pack_rhs<
1251     Scalar, Index,
1252     TensorContractionSubMapper<
1253         Scalar, Index, Rhs,
1254         TensorEvaluator<
1255             const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1256             Device>,
1257         nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1258     nr, ColMajor, false, false>
1259 {
1260   typedef TensorContractionSubMapper<
1261       Scalar, Index, Rhs,
1262       TensorEvaluator<
1263           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1264           Device>,
1265       nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>
1266       SubMapper;
1267   typedef SubMapper DataMapper;
1268   typedef typename packet_traits<Scalar>::type Packet;
1269
1270   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1271
1272   EIGEN_DEVICE_FUNC
1273   EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1274                                     Index stride = 0, Index offset = 0) const
1275   {
1276     eigen_assert(stride == 0);
1277     eigen_assert(offset == 0);
1278
1279     (void)stride;
1280     (void)offset;
1281
1282     const int packet_size = 2;
1283     const Index packet_cols4 = (cols / 4) * 4;
1284     const Index peeled_k = (depth / packet_size) * packet_size;
1285     const bool non_standard_patches = rhs.nonStandardPatches();
1286
1287     for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1288     {
1289       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1290       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1291       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1292       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1293
1294       Index k = 0;
1295       if (!non_standard_patches)
1296       {
1297         // FAST PATH:
1298         // Iterate over patch columns and rows if we know that a single
1299         // packet do not span across multiple rows or columns.
1300         if ((rhs.patchDepth() % packet_size) == 0)
1301         {
1302           const Index start_col = rhs.colOffset();
1303           const Index max_col = rhs.maxCol(peeled_k);
1304
1305           for (Index c = start_col; c < max_col; ++c)
1306           {
1307             eigen_assert(k <= peeled_k);
1308
1309             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1310             const Index max_row = rhs.maxRow(peeled_k, c);
1311
1312             const bool pad_col0 = dm0.padCol(c);
1313             const bool pad_col1 = dm1.padCol(c);
1314             const bool pad_col2 = dm2.padCol(c);
1315             const bool pad_col3 = dm3.padCol(c);
1316
1317             // We can squeeze reads along the `row` and `depth` dimensions if
1318             // the row stride is `1`, which means that `row` and `depth`
1319             // dimensions are contiguous (two innermost dimensions).
1320             if (rhs.rowStride() == 1 &&                               //
1321                 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&   //
1322                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1323                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1324                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1325                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1326             {
1327               // Compute how many elements we can squeeze read.
1328               const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1329
1330               // Upper bound for the number of elements in the depth dimension
1331               // that we can squeeze read.
1332               const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1333
1334               // Do not overshoot beyond the block size.
1335               const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1336               eigen_assert((max_depth - start_depth) % packet_size == 0);
1337
1338               const Index idx0 = dm0.baseIndex(start_row, c);
1339               const Index idx1 = dm1.baseIndex(start_row, c);
1340               const Index idx2 = dm2.baseIndex(start_row, c);
1341               const Index idx3 = dm3.baseIndex(start_row, c);
1342
1343               for (Index d = start_depth; d < max_depth; d += packet_size)
1344               {
1345                 PacketBlock<Packet, 2> kernel0;
1346                 PacketBlock<Packet, 2> kernel1;
1347                 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1348                 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1349                 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1350                 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1351                 ptranspose(kernel0);
1352                 ptranspose(kernel1);
1353                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1354                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1355                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1356                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1357                 block += 4 * packet_size;
1358                 k += packet_size;
1359               }
1360
1361               // Go to the next column.
1362               continue;
1363             }
1364
1365             // If we can't squeeze reads, process rows one by one.
1366             for (Index r = start_row; r < max_row; ++r)
1367             {
1368               eigen_assert(k <= peeled_k);
1369
1370               const bool pad0 = pad_col0 || dm0.padRow(r);
1371               const bool pad1 = pad_col1 || dm1.padRow(r);
1372               const bool pad2 = pad_col2 || dm2.padRow(r);
1373               const bool pad3 = pad_col3 || dm3.padRow(r);
1374
1375               const Index idx0 = dm0.baseIndex(r, c);
1376               const Index idx1 = dm1.baseIndex(r, c);
1377               const Index idx2 = dm2.baseIndex(r, c);
1378               const Index idx3 = dm3.baseIndex(r, c);
1379
1380               const Index start_depth =
1381                   ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1382               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1383               eigen_assert((max_depth - start_depth) % packet_size == 0);
1384
1385               for (Index d = start_depth; d < max_depth; d += packet_size)
1386               {
1387                 eigen_assert(k < peeled_k);
1388                 PacketBlock<Packet, 2> kernel0;
1389                 PacketBlock<Packet, 2> kernel1;
1390                 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1391                 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1392                 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1393                 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1394                 ptranspose(kernel0);
1395                 ptranspose(kernel1);
1396                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1397                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1398                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1399                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1400                 block += 4 * packet_size;
1401                 k += packet_size;
1402               }
1403             }
1404           }
1405
1406           // The loop above should fill peeled_k elements.
1407           eigen_assert(peeled_k == k);
1408         }
1409         else
1410         {
1411           // Packet can span multiple rows or columns, so we have to go
1412           // though the slower "standard" path.
1413           for (; k < peeled_k; k += packet_size)
1414           {
1415             PacketBlock<Packet, 2> kernel0;
1416             PacketBlock<Packet, 2> kernel1;
1417             kernel0.packet[0] = dm0.loadPacketStandard(k);
1418             kernel0.packet[1] = dm1.loadPacketStandard(k);
1419             kernel1.packet[0] = dm2.loadPacketStandard(k);
1420             kernel1.packet[1] = dm3.loadPacketStandard(k);
1421             ptranspose(kernel0);
1422             ptranspose(kernel1);
1423             pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1424             pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1425             pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1426             pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1427             block += 4 * packet_size;
1428           }
1429         }
1430       }
1431
1432       // Copy the remaining coefficients of the column block after the peeled_k.
1433       if (!non_standard_patches)
1434       {
1435         for (; k < depth; k++)
1436         {
1437           block[0] = dm0.loadCoeffStandard(k);
1438           block[1] = dm1.loadCoeffStandard(k);
1439           block[2] = dm2.loadCoeffStandard(k);
1440           block[3] = dm3.loadCoeffStandard(k);
1441           block += 4;
1442         }
1443       }
1444       else
1445       {
1446         for (; k < depth; k++)
1447         {
1448           block[0] = dm0(k);
1449           block[1] = dm1(k);
1450           block[2] = dm2(k);
1451           block[3] = dm3(k);
1452           block += 4;
1453         }
1454       }
1455     }
1456
1457     // Copy the remaining columns one at a time (nr==1).
1458     for (Index j2 = packet_cols4; j2 < cols; ++j2)
1459     {
1460       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1461       for (Index k = 0; k < depth; k++)
1462       {
1463         *block = dm0(k);
1464         block += 1;
1465       }
1466     }
1467   }
1468 };
1469
1470 // Special case for non-vectorized types such as float16.
1471 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1472           typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1473           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1474 struct gemm_pack_rhs<
1475     Scalar, Index,
1476     TensorContractionSubMapper<
1477         Scalar, Index, Rhs,
1478         TensorEvaluator<
1479             const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1480             Device>,
1481         nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1482     nr, ColMajor, false, false>
1483 {
1484   typedef TensorContractionSubMapper<
1485       Scalar, Index, Rhs,
1486       TensorEvaluator<
1487           const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1488           Device>,
1489       nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
1490       SubMapper;
1491   typedef SubMapper DataMapper;
1492
1493   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1494
1495   EIGEN_DEVICE_FUNC
1496   EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1497                                     Index stride = 0, Index offset = 0) const
1498   {
1499     eigen_assert(stride == 0);
1500     eigen_assert(offset == 0);
1501
1502     (void)offset;
1503     (void)stride;
1504
1505     const Index packet_cols4 = (cols / 4) * 4;
1506
1507     for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1508     {
1509       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1510       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1511       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1512       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1513
1514       if (!rhs.nonStandardPatches())
1515       {
1516         for (Index k = 0; k < depth; k++)
1517         {
1518           block[0] = dm0.loadCoeffStandard(k);
1519           block[1] = dm1.loadCoeffStandard(k);
1520           block[2] = dm2.loadCoeffStandard(k);
1521           block[3] = dm3.loadCoeffStandard(k);
1522           block += 4;
1523         }
1524       }
1525       else
1526       {
1527         for (Index k = 0; k < depth; k++)
1528         {
1529           block[0] = dm0(k);
1530           block[1] = dm1(k);
1531           block[2] = dm2(k);
1532           block[3] = dm3(k);
1533           block += 4;
1534         }
1535       }
1536     }
1537
1538     // Copy the remaining columns one at a time (nr==1).
1539     for (Index j2 = packet_cols4; j2 < cols; ++j2)
1540     {
1541       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1542       for (Index k = 0; k < depth; k++)
1543       {
1544         *block = dm0(k);
1545         block += 1;
1546       }
1547     }
1548   }
1549 };
1550 } // end namespace internal
1551
1552 /** SpatialConvolution
1553  * \ingroup CXX11_NeuralNetworks_Module
1554  *
1555  * \brief Applies a 2D convolution over a multichannel input image.
1556  *
1557  * The input parameter is expected to be a tensor with a rank of 3 or more
1558  * (channels, height, width, and optionally others)
1559  * The kernel parameter is expected to be a 4D tensor (filters, channels,
1560  * kernel_height, kernel_width)
1561  * The input and the kernel must both be in col-major layout. The result will
1562  * also be in col-major layout.
1563  *
1564  * If col_in_stride, row_in_stride > 1, then applies convolution with holes
1565  * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
1566  * pixels.
1567  *
1568  * If padding_top, padding_bottom, padding_left, or padding_right is specified,
1569  * then those paddings will be used to pad the input, and padding_type must be
1570  * PADDING_VALID.
1571  *
1572  * The result can be assigned to a tensor of rank equal to the rank of the
1573  * input. The dimensions of the result will be filters, height, width (and
1574  * others if applicable).
1575  *
1576  * It is possible to swap the order of the width and height dimensions provided
1577  * that the same order is used in the input, the kernel, and the output.
1578  *
1579  * It is also possible to add an output kernel to the contraction, output
1580  * kernel is called by Eigen when it "finalizes" the block of an output tensor.
1581  *
1582  */
1583 template <typename Input, typename Kernel, typename OutputKernel = const NoOpOutputKernel>
1584 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename internal::conditional<
1585     internal::traits<Input>::Layout == ColMajor,
1586     TensorReshapingOp<
1587         const DSizes<typename internal::traits<Input>::Index,
1588                      internal::traits<Input>::NumDimensions>,
1589         const TensorContractionOp<
1590             const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1591             const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1592                                     const Kernel>,
1593             const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1594                                     const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1595             const OutputKernel>>,
1596     TensorReshapingOp<
1597         const DSizes<typename internal::traits<Input>::Index,
1598                      internal::traits<Input>::NumDimensions>,
1599         const TensorContractionOp<
1600             const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1601             const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1602                                     const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1603             const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1604                                     const Kernel>,
1605             const OutputKernel>>>::type
1606 SpatialConvolution(const Input &input, const Kernel &kernel, const Index row_stride = 1,
1607                    const Index col_stride = 1, const PaddingType padding_type = PADDING_SAME,
1608                    const Index row_in_stride = 1, const Index col_in_stride = 1,
1609                    const OutputKernel &output_kernel = OutputKernel(), Index padding_top = 0,
1610                    Index padding_bottom = 0, Index padding_left = 0, Index padding_right = 0)
1611 {
1612   typedef typename internal::traits<Input>::Index TensorIndex;
1613   TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions,
1614                    internal::traits<Input>::Layout, TensorIndex>>
1615       in(input);
1616   TensorRef<
1617       Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions,
1618              internal::traits<Kernel>::Layout, TensorIndex>>
1619       kern(kernel);
1620
1621   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1622                       YOU_MADE_A_PROGRAMMING_MISTAKE)
1623   const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1624
1625   const int NumDims = internal::traits<Input>::NumDimensions;
1626
1627   // Number of filters to apply. This is the same as the output depth of the
1628   // result
1629   const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
1630   // Number of channels. This is the same as the input depth.
1631   const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
1632   const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
1633   const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
1634
1635   const Index kernelRowsEff = kernelRows + (kernelRows - 1) * (row_in_stride - 1);
1636   const Index kernelColsEff = kernelCols + (kernelCols - 1) * (col_in_stride - 1);
1637
1638   array<IndexPair<TensorIndex>, 1> contract_dims;
1639   contract_dims[0] = IndexPair<TensorIndex>(1, 0);
1640
1641   const TensorIndex InputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1642   const TensorIndex InputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1643   const bool padding_explicit = (padding_top || padding_bottom || padding_left || padding_right);
1644
1645   TensorIndex out_height;
1646   TensorIndex out_width;
1647   switch (padding_type)
1648   {
1649     case PADDING_VALID:
1650     {
1651       const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
1652       const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
1653       out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
1654       out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
1655       break;
1656     }
1657     case PADDING_SAME:
1658     {
1659       eigen_assert(!padding_explicit);
1660       out_height = divup(InputRows, row_stride);
1661       out_width = divup(InputCols, col_stride);
1662       break;
1663     }
1664     default:
1665     {
1666       // Initialize unused variables to avoid a compiler warning
1667       out_height = 0;
1668       out_width = 0;
1669       eigen_assert(false && "unexpected padding");
1670     }
1671   }
1672
1673   // Molds the output of the patch extraction code into a 2d tensor:
1674   // - the first dimension (dims[0]): the patch values to be multiplied with the
1675   // kernels
1676   // - the second dimension (dims[1]): everything else
1677   DSizes<TensorIndex, 2> pre_contract_dims;
1678   if (isColMajor)
1679   {
1680     pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
1681     pre_contract_dims[1] = out_height * out_width;
1682     for (int i = 3; i < NumDims; ++i)
1683     {
1684       pre_contract_dims[1] *= in.dimension(i);
1685     }
1686   }
1687   else
1688   {
1689     pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
1690     pre_contract_dims[0] = out_height * out_width;
1691     for (int i = 0; i < NumDims - 3; ++i)
1692     {
1693       pre_contract_dims[0] *= in.dimension(i);
1694     }
1695   }
1696
1697   // Molds the output of the contraction into the shape expected by the used
1698   // (assuming this is ColMajor):
1699   // - 1st dim: kernel filters
1700   // - 2nd dim: output height
1701   // - 3rd dim: output width
1702   // - 4th dim and beyond: everything else including batch size
1703   DSizes<TensorIndex, NumDims> post_contract_dims;
1704   if (isColMajor)
1705   {
1706     post_contract_dims[0] = kernelFilters;
1707     post_contract_dims[1] = out_height;
1708     post_contract_dims[2] = out_width;
1709     for (int i = 3; i < NumDims; ++i)
1710     {
1711       post_contract_dims[i] = in.dimension(i);
1712     }
1713   }
1714   else
1715   {
1716     post_contract_dims[NumDims - 1] = kernelFilters;
1717     post_contract_dims[NumDims - 2] = out_height;
1718     post_contract_dims[NumDims - 3] = out_width;
1719     for (int i = 0; i < NumDims - 3; ++i)
1720     {
1721       post_contract_dims[i] = in.dimension(i);
1722     }
1723   }
1724
1725   DSizes<TensorIndex, 2> kernel_dims;
1726   if (isColMajor)
1727   {
1728     kernel_dims[0] = kernelFilters;
1729     kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
1730   }
1731   else
1732   {
1733     kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
1734     kernel_dims[1] = kernelFilters;
1735   }
1736   if (padding_explicit)
1737   {
1738     return choose(
1739         Cond<internal::traits<Input>::Layout == ColMajor>(),
1740         kernel.reshape(kernel_dims)
1741             .contract(input
1742                           .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1743                                                  row_in_stride, col_in_stride,
1744                                                  /*row_inflate_stride=*/1,
1745                                                  /*col_inflate_stride=*/1, padding_top,
1746                                                  padding_bottom, padding_left, padding_right,
1747                                                  /*padding_value=*/0)
1748                           .reshape(pre_contract_dims),
1749                       contract_dims, output_kernel)
1750             .reshape(post_contract_dims),
1751         input
1752             .extract_image_patches(
1753                 kernelRows, kernelCols, row_stride, col_stride, row_in_stride, col_in_stride,
1754                 /*row_inflate_stride=*/1,
1755                 /*col_inflate_stride=*/1, padding_top, padding_bottom, padding_left, padding_right,
1756                 /*padding_value=*/0)
1757             .reshape(pre_contract_dims)
1758             .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1759             .reshape(post_contract_dims));
1760   }
1761   else
1762   {
1763     return choose(
1764         Cond<internal::traits<Input>::Layout == ColMajor>(),
1765         kernel.reshape(kernel_dims)
1766             .contract(input
1767                           .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1768                                                  row_in_stride, col_in_stride, padding_type)
1769                           .reshape(pre_contract_dims),
1770                       contract_dims, output_kernel)
1771             .reshape(post_contract_dims),
1772         input
1773             .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
1774                                    col_in_stride, padding_type)
1775             .reshape(pre_contract_dims)
1776             .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1777             .reshape(post_contract_dims));
1778   }
1779 }
1780
1781 } // end namespace Eigen
1782
1783 #endif // __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__