Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / eigen / eigen_spatial_convolutions-inl.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
19 #define __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
20
21 #include "cker/eigen/eigen_convolution_helpers.h"
22
23 // Note this header is used in both TF and TFLite.
24 namespace Eigen
25 {
26
27 namespace internal
28 {
29
30 // WARNING: Most of the code here implicitly assumes that the matrix is in
31 // ColMajor layout. This is guaranteed by the tensor contraction (see
32 // TensorContraction.h).
33 //
34 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
35 // We don't want to actually extract image patches and reshape the result into
36 // a matrix (this involves allocating huge extra memory), so the patch
37 // extraction and reshape operations are implicit.
38 //
39 // TensorContractionInputMapper takes a matrix index and returns the coefficient
40 // (or the packet) of the "virtual tensor", that would be at that index if we
41 // were to actually reshape the result of patch extraction.
42 //
43 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
44 // at the given vertical and horizontal offsets.
45 //
46 // "Virtual matrix" dimensions:
47 //   *0: kernelChannels * kernelRows * kernelCols;
48 //    1: out_height * out_width; * OTHERS (e.g batches, etc...)
49 //
50 // *) extracted patches are continuous in memory (innermost dimension assuming
51 //    col major layout)
52 //
53 // With this dimensions:
54 //   row - offset within a single patch (in code: patchId)
55 //   col - index of the extracted patch (in code: patchIndex)
56 //         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
57 //
58 // TODO(ezhulenev): Consolidate this part of the code with the image patch
59 // extraction code since they are both very similar.
60
61 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
62           typename Scalar_, typename Index, typename nocontract_t, typename contract_t, int Side,
63           int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
64 class TensorContractionInputMapper<
65   Scalar_, Index, Side,
66   TensorEvaluator<
67     const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
68   nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
69 {
70 public:
71   typedef Scalar_ Scalar;
72
73   typedef TensorContractionInputMapper<
74     Scalar, Index, Side,
75     TensorEvaluator<
76       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
77     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
78     Self;
79
80   typedef TensorContractionSubMapper<
81     Scalar, Index, Side,
82     TensorEvaluator<
83       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
84     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
85     SubMapper;
86
87   typedef SubMapper VectorMapper;
88   typedef SubMapper LinearMapper;
89   typedef typename packet_traits<Scalar>::type Packet;
90
91   typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
92
93   EIGEN_DEVICE_FUNC
94   TensorContractionInputMapper(
95     const TensorEvaluator<
96       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>
97       &tensor,
98     const nocontract_t &, const nocontract_t &, const contract_t &, const contract_t &)
99     : m_impl(tensor.impl().impl())
100   {
101     Index patch_rows;
102     Index patch_depth;
103     if (internal::traits<ArgType>::Layout == ColMajor)
104     {
105       patch_depth = tensor.impl().dimensions()[0];
106       patch_rows = tensor.impl().dimensions()[1];
107       m_patch_cols = tensor.impl().dimensions()[2];
108       m_num_patches = tensor.impl().dimensions()[3];
109     }
110     else
111     {
112       const size_t NumDims = tensor.impl().dimensions().size();
113       patch_depth = tensor.impl().dimensions()[NumDims - 1];
114       patch_rows = tensor.impl().dimensions()[NumDims - 2];
115       m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
116       m_num_patches = tensor.impl().dimensions()[NumDims - 4];
117     }
118
119     // Strides for navigating through the single patch.
120     m_patch_row_stride = patch_depth;
121     m_patch_col_stride = patch_rows * m_patch_row_stride;
122
123     m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
124     m_patch_col_inflate_strides = tensor.impl().colInflateStride();
125
126     m_colStride = patch_rows;
127
128     m_outputRows = tensor.impl().outputRows();
129     m_outputCols = tensor.impl().outputCols();
130     m_row_strides = tensor.impl().userRowStride();
131     m_col_strides = tensor.impl().userColStride();
132
133     m_in_row_strides = tensor.impl().userInRowStride();
134     m_in_col_strides = tensor.impl().userInColStride();
135
136     if (internal::traits<ArgType>::Layout == ColMajor)
137     {
138       m_inputRows = tensor.impl().impl().dimensions()[1];
139       m_inputCols = tensor.impl().impl().dimensions()[2];
140     }
141     else
142     {
143       const int NumDims = tensor.impl().impl().dimensions().size();
144       m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
145       m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
146     }
147
148     m_rowInputStride = patch_depth;
149     m_colInputStride = patch_depth * m_inputRows;
150     m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
151
152     m_rowPaddingTop = tensor.impl().rowPaddingTop();
153     m_colPaddingLeft = tensor.impl().colPaddingLeft();
154
155     m_fastPatchRowStride = internal::TensorIntDivisor<Index>(m_patch_row_stride);
156     m_fastPatchColStride = internal::TensorIntDivisor<Index>(m_patch_col_stride);
157     m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
158     m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
159     m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
160     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
161     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
162     m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
163   }
164
165   EIGEN_DEVICE_FUNC
166   TensorContractionInputMapper(const TensorContractionInputMapper &base_mapper)
167     : m_impl(base_mapper.m_impl)
168   {
169     m_patch_cols = base_mapper.m_patch_cols;
170     m_num_patches = base_mapper.m_num_patches;
171
172     m_patch_row_stride = base_mapper.m_patch_row_stride;
173     m_patch_col_stride = base_mapper.m_patch_col_stride;
174
175     m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
176     m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
177
178     m_colStride = base_mapper.m_colStride;
179
180     m_rowInputStride = base_mapper.m_rowInputStride;
181     m_colInputStride = base_mapper.m_colInputStride;
182     m_patchInputStride = base_mapper.m_patchInputStride;
183
184     m_inputRows = base_mapper.m_inputRows;
185     m_inputCols = base_mapper.m_inputCols;
186
187     m_outputRows = base_mapper.m_outputRows;
188     m_outputCols = base_mapper.m_outputCols;
189     m_row_strides = base_mapper.m_row_strides;
190     m_col_strides = base_mapper.m_col_strides;
191
192     m_in_row_strides = base_mapper.m_in_row_strides;
193     m_in_col_strides = base_mapper.m_in_col_strides;
194
195     m_rowPaddingTop = base_mapper.m_rowPaddingTop;
196     m_colPaddingLeft = base_mapper.m_colPaddingLeft;
197
198     m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
199     m_fastPatchColStride = base_mapper.m_fastPatchColStride;
200     m_fastInputRowStride = base_mapper.m_fastInputRowStride;
201     m_fastInputColStride = base_mapper.m_fastInputColStride;
202     m_fastNumPatches = base_mapper.m_fastNumPatches;
203     m_fastColStride = base_mapper.m_fastColStride;
204     m_fastOutputRows = base_mapper.m_fastOutputRows;
205     m_fastDimZero = base_mapper.m_fastDimZero;
206   }
207
208   // If true, turns off some optimizations for loading packets since the image
209   // patches are "non-standard" such as there are non-trivial strides or
210   // inflations in the input.
211   EIGEN_DEVICE_FUNC
212   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const
213   {
214     return m_in_row_strides != 1 || m_in_col_strides != 1 || m_patch_row_inflate_strides != 1 ||
215            m_patch_col_inflate_strides != 1;
216   }
217
218   EIGEN_DEVICE_FUNC
219   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const
220   {
221     return SubMapper(*this, i, j);
222   }
223
224   EIGEN_DEVICE_FUNC
225   EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const
226   {
227     return LinearMapper(*this, i, j);
228   }
229
230   EIGEN_DEVICE_FUNC
231   EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const
232   {
233     Index rowIndex, colIndex, otherIndex;
234     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
235     return loadCoeff(row, rowIndex, colIndex, otherIndex);
236   }
237
238   // Load the coefficient at the patchIndex location instead of the usual
239   // m_rowIndex,
240   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
241   // EIGEN_DEVICE_FUNC
242   EIGEN_DEVICE_FUNC
243   EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const
244   {
245     Index rowIndex, colIndex, otherIndex;
246     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
247     return loadCoeff(row, rowIndex, colIndex, otherIndex);
248   }
249
250   EIGEN_DEVICE_FUNC
251   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const
252   {
253     Index rowIndex, colIndex, otherIndex;
254     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
255     return loadPacket(row, rowIndex, colIndex, otherIndex);
256   }
257
258   // Load the packet at the patchIndex location instead of the usual m_rowIndex,
259   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
260   EIGEN_DEVICE_FUNC
261   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const
262   {
263     Index rowIndex, colIndex, otherIndex;
264     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
265     return loadPacket(row, rowIndex, colIndex, otherIndex);
266   }
267
268   EIGEN_DEVICE_FUNC
269   EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device> &impl() const { return m_impl; }
270
271   EIGEN_DEVICE_FUNC
272   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
273   EIGEN_DEVICE_FUNC
274   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
275   EIGEN_DEVICE_FUNC
276   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
277
278 private:
279   friend class TensorContractionSubMapper<
280     Scalar, Index, Side,
281     TensorEvaluator<
282       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
283     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>;
284
285   // Load coefficient from a patch specified by the "within patch offset"
286   // (patchId) and the precomputed indices of the first element of the patch.
287   EIGEN_DEVICE_FUNC
288   EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex,
289                                        Index otherIndex) const
290   {
291     // Find the offset of the element wrt the location of the first element.
292     const Index patchOffset = patchId / m_fastDimZero;
293
294     const Index colOffset = patchOffset / m_fastColStride;
295     const Index inputCol = colIndex + colOffset * m_in_col_strides;
296     const Index origInputCol = (m_patch_col_inflate_strides == 1)
297                                  ? inputCol
298                                  : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
299
300     const Index rowOffset = patchOffset - colOffset * m_colStride;
301     const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
302     const Index origInputRow = (m_patch_row_inflate_strides == 1)
303                                  ? inputRow
304                                  : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
305     if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
306         origInputRow >= m_inputRows || (inputCol != origInputCol * m_patch_col_inflate_strides) ||
307         (inputRow != origInputRow * m_patch_row_inflate_strides))
308     {
309       return Scalar(0);
310     }
311     const Index depth = patchId - patchOffset * patchDepth();
312     const Index inputIndex =
313       depth + origInputRow * m_rowInputStride + origInputCol * m_colInputStride + otherIndex;
314     return m_impl.coeff(inputIndex);
315   }
316
317   // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
318   // and `in_strides` equal to 1 (template specialization without templates).
319   EIGEN_DEVICE_FUNC
320   EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex,
321                                                Index otherIndex) const
322   {
323     eigen_assert(!nonStandardPatches());
324
325     // Find the offset of the element wrt the location of the first element.
326     const Index patchOffset = patchId / m_fastDimZero;
327     const Index colOffset = patchOffset / m_fastColStride;
328     const Index rowOffset = patchOffset - colOffset * m_colStride;
329     const Index inputCol = colIndex + colOffset;
330     const Index inputRow = rowIndex + rowOffset;
331     if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows)
332     {
333       return Scalar(0);
334     }
335     const Index depth = patchId - patchOffset * patchDepth();
336     const Index inputIndex =
337       depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
338     return m_impl.coeff(inputIndex);
339   }
340
341   // Load packet from a patch specified by the "within patch offset"
342   // (patchId) and the precomputed indices of the first element of the patch.
343   EIGEN_DEVICE_FUNC
344   EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex,
345                                         Index otherIndex) const
346   {
347     const Index packetSize = internal::unpacket_traits<Packet>::size;
348     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
349     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
350
351     if (nonStandardPatches())
352     {
353       return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
354     }
355     typedef decltype(m_impl) TensorEvaluatorT;
356     return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex, colIndex, otherIndex);
357   }
358
359   // Helper function to load a 'partial' packet - this is the single column
360   // part of a packet that is split across two columns. In the 'partial' packet,
361   // the elements corresponding to the column (specified through colOffset) are
362   // loaded and the rest of the elements are zero-filled into the 'partial'
363   // packet. This function is called from loadPacketStandardFromTwoColumns().
364   // This code path is exercised only when the packet type supports masked load
365   // and when the partial packet load is available in the TensorEvaluator.
366   EIGEN_DEVICE_FUNC
367   EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(Index rowIndex, Index colIndex,
368                                                        Index otherIndex, Index patchId,
369                                                        const Index span[],
370                                                        const Index patchOffsets[],
371                                                        Index colOffset) const
372   {
373     const Index inputCol = colIndex + colOffset;
374     const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
375                                  patchOffsets[1] - colOffset * m_colStride};
376     const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
377
378     if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || inputCol >= m_inputCols || inputCol < 0)
379     {
380       // Partial packet is all zeros
381       return internal::pset1<Packet>(Scalar(0));
382     }
383     else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
384     {
385       // From inputIndex-span[0], we need to load elements starting from index
386       // span[0] all the way upto (and including) span[1].
387       const Index depth = patchId - patchOffsets[0] * patchDepth();
388       const Index inputIndex =
389         depth + inputRows[0] * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
390       return m_impl.template partialPacket<Packet>(inputIndex - span[0],
391                                                    mask<Packet>(span[0], span[1] + 1));
392     }
393     else
394     {
395       // Using slow path for this partial packet.
396       // We need to load elements starting from index span[0] all the way upto
397       // (and including) span[1]. We split this load into 3 parts:
398       // 0 : span[0]-1 - Zeros will be loaded for these indices
399       // span[0] : span[1] - Elements will be loaded here for these indices
400       // span[1]+1 : packetSize-1 - Zeross will be loaded for these indices
401       const Index packetSize = internal::unpacket_traits<Packet>::size;
402       EIGEN_ALIGN_MAX
403       typename internal::remove_const<Scalar>::type values[packetSize];
404       for (int i = 0; i < span[0]; ++i)
405         values[i] = Scalar(0);
406       for (int i = span[0]; i < span[1] + 1; ++i)
407         values[i] = loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
408       for (int i = span[1] + 1; i < packetSize; ++i)
409         values[i] = Scalar(0);
410       return internal::pload<Packet>(values);
411     }
412   }
413
414   // Helper function to load a packet that is split across two columns.
415   // If required, this function is called from loadPacketStandard() when the
416   // packet type supports masked load and when the partial packet load is
417   // available in the TensorEvaluator.
418   EIGEN_DEVICE_FUNC
419   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(Index patchId, Index rowIndex,
420                                                               Index colIndex, Index otherIndex,
421                                                               const Index patchOffsets[],
422                                                               const Index colOffsets[]) const
423   {
424     eigen_assert(colOffsets[1] == colOffsets[0] + 1);
425     const Index packetSize = internal::unpacket_traits<Packet>::size;
426
427     // Packet to load will be split into 2 parts where each part spans a single
428     // column. First determine where to split.
429     const Index patchIdSplit = ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
430     const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
431
432     // patchIds[i]:          patchId corresponding to partial packet i
433     // spans[i]:             Start and end indices corresponding to the elements
434     //                       to be loaded for partial packet i
435     // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
436     const Index patchIds[2] = {patchId, patchIdSplit + 1};
437     const Index spans[2][2] = {{0, patchIdSplit - patchId},
438                                {patchIdSplit - patchId + 1, packetSize - 1}};
439     const Index patchOffsets2Cols[2][2] = {{patchOffsets[0], patchOffsetSplit},
440                                            {patchOffsetSplit + 1, patchOffsets[1]}};
441
442     // Load partial packets and do bit-wise OR to generate required packet
443     return internal::por<Packet>(
444       loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], spans[0],
445                                 patchOffsets2Cols[0], colOffsets[0]),
446       loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], spans[1],
447                                 patchOffsets2Cols[1], colOffsets[1]));
448   }
449
450   // Helper function to load a packet that is present in a single columns.
451   // If required, this function is called from loadPacketStandard().
452   EIGEN_DEVICE_FUNC
453   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(Index patchId, Index rowIndex,
454                                                                 Index colIndex, Index otherIndex,
455                                                                 const Index patchOffsets[],
456                                                                 const Index colOffsets[],
457                                                                 const Index inputCols[]) const
458   {
459     eigen_assert(colOffsets[0] == colOffsets[1]);
460     const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
461                                  patchOffsets[1] - colOffsets[1] * m_colStride};
462     eigen_assert(rowOffsets[0] <= rowOffsets[1]);
463     const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
464
465     if (inputRows[0] >= m_inputRows || inputRows[1] < 0)
466     {
467       // all zeros
468       return internal::pset1<Packet>(Scalar(0)); // all zeros
469     }
470
471     if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
472     {
473       // no padding
474       const Index depth = patchId - patchOffsets[0] * patchDepth();
475       const Index inputIndex =
476         depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
477       return m_impl.template packet<Unaligned>(inputIndex);
478     }
479     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
480   }
481
482   // Load standard packet from a patch specified by the "within patch offset"
483   // (patchId) and the precomputed indices of the first element of the patch.
484   // This function will be called if partial packet loading is not available
485   // for the TensorEvaluator or if the packet type does not support masked
486   // load.
487   template <typename PacketT, typename TensorEvaluatorT>
488   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
489     !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
490   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
491   {
492     const Index packetSize = internal::unpacket_traits<Packet>::size;
493     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
494     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
495
496     eigen_assert(!nonStandardPatches());
497
498     if ((patchDepth() % packetSize) == 0)
499     {
500       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
501     }
502
503     // Offsets and input calculation here are identical to
504     // loadCoeffStandard(...), but repeated twice.
505     const Index patchOffsets[2] = {patchId / m_fastDimZero,
506                                    (patchId + packetSize - 1) / m_fastDimZero};
507     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
508                                  patchOffsets[1] / m_fastColStride};
509     const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
510
511     if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
512     {
513       // all zeros
514       return internal::pset1<Packet>(Scalar(0));
515     }
516     if (inputCols[0] == inputCols[1])
517     {
518       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
519                                                 patchOffsets, colOffsets, inputCols);
520     }
521     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
522   }
523
524   // Load standard packet from a patch specified by the "within patch offset"
525   // (patchId) and the precomputed indices of the first element of the patch.
526   // This function will be called if partial packet loading is available for
527   // the TensorEvaluator and if the packet type supports masked load.
528   // The only difference between this and the other case is that if the packet
529   // to load is split across two columns, then in this case instead of going to
530   // the slow (element-by-element) load, we load two packets - each containing
531   // elements from one of the columns (rest of the elements of the packets are
532   // zeroes), and then combine these two packets to generate the required
533   // packet. The idea is to enable fast load (if possible) of these 'partial'
534   // packets.
535   template <typename PacketT, typename TensorEvaluatorT>
536   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
537     TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
538   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
539   {
540     const Index packetSize = internal::unpacket_traits<PacketT>::size;
541     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
542     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
543
544     eigen_assert(!nonStandardPatches());
545
546     if ((patchDepth() % packetSize) == 0)
547     {
548       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
549     }
550
551     // Offsets and input calculation here are identical to
552     // loadCoeffStandard(...), but repeated twice.
553     const Index patchOffsets[2] = {patchId / m_fastDimZero,
554                                    (patchId + packetSize - 1) / m_fastDimZero};
555     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
556                                  patchOffsets[1] / m_fastColStride};
557     const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
558
559     if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
560     {
561       // all zeros
562       return internal::pset1<PacketT>(Scalar(0));
563     }
564     if (inputCols[0] == inputCols[1])
565     {
566       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
567                                                 patchOffsets, colOffsets, inputCols);
568     }
569     if (inputCols[1] == inputCols[0] + 1)
570     {
571       return loadPacketStandardFromTwoColumns(patchId, rowIndex, colIndex, otherIndex, patchOffsets,
572                                               colOffsets);
573     }
574     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
575   }
576
577   EIGEN_DEVICE_FUNC
578   EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, Index colIndex,
579                                             Index otherIndex) const
580   {
581     const Index packetSize = internal::unpacket_traits<Packet>::size;
582     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
583     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
584
585     eigen_assert(!nonStandardPatches());
586     eigen_assert((patchDepth() % packetSize) == 0);
587     // Find the offset of the element wrt the location of the first element.
588     const Index patchOffset = patchId / m_fastDimZero;
589     eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
590
591     const Index colOffset = patchOffset / m_fastColStride;
592     const Index rowOffset = patchOffset - colOffset * m_colStride;
593     const Index inputCol = colIndex + colOffset;
594     const Index inputRow = rowIndex + rowOffset;
595     if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || inputRow >= m_inputRows)
596     {
597       // all zeros
598       return internal::pset1<Packet>(Scalar(0));
599     }
600     // no padding
601     const Index depth = patchId - patchOffset * patchDepth();
602     const Index inputIndex =
603       depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
604     return m_impl.template packet<Unaligned>(inputIndex);
605   }
606
607   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex,
608                                                                       Index colIndex,
609                                                                       Index otherIndex) const
610   {
611     const int packetSize = internal::unpacket_traits<Packet>::size;
612     EIGEN_ALIGN_MAX
613     typename internal::remove_const<Scalar>::type values[packetSize];
614     for (int i = 0; i < packetSize; ++i)
615     {
616       values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
617     }
618     Packet rslt = internal::pload<Packet>(values);
619     return rslt;
620   }
621
622   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
623   computeBaseIndices(Index patchIndex, Index &rowIndex, Index &colIndex, Index &otherIndex) const
624   {
625     const size_t NumInputDims =
626       array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
627     otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
628     const Index patch2DIndex =
629       (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
630     otherIndex *= m_patchInputStride;
631     colIndex = patch2DIndex / m_fastOutputRows;
632     rowIndex = patch2DIndex - colIndex * m_outputRows;
633     colIndex = colIndex * m_col_strides - m_colPaddingLeft;
634     rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
635   }
636
637   Index m_patch_cols;  // number of columns in the patch
638   Index m_num_patches; // number of patches to extract.
639
640   // Strides for navigating through the single patch.
641   Index m_patch_row_stride;
642   Index m_patch_col_stride;
643   internal::TensorIntDivisor<Index> m_fastPatchRowStride;
644   internal::TensorIntDivisor<Index> m_fastPatchColStride;
645
646   Index m_patch_row_inflate_strides; // the strides for row inflation in the
647                                      // image patch
648   Index m_patch_col_inflate_strides; // the strides for col inflation in the
649                                      // image patch
650   // Fast representation of inflation strides.
651   internal::TensorIntDivisor<Index> m_fastInputRowStride;
652   internal::TensorIntDivisor<Index> m_fastInputColStride;
653
654   Index m_otherStride;
655   Index m_colStride;
656   internal::TensorIntDivisor<Index> m_fastNumPatches;
657   internal::TensorIntDivisor<Index> m_fastColStride;
658
659   Index m_rowInputStride;   // row stride in the input tensor
660   Index m_colInputStride;   // col stride in the input tensor
661   Index m_patchInputStride; // patch stride in the input tensor
662
663   Index m_inputRows; // Number of rows in the input tensor
664   Index m_inputCols; // Number of cols in the input tensor
665
666   Index m_outputRows; // Number of convolution output rows
667   Index m_outputCols; // Number of convolution output column
668
669   Index m_row_strides; // User specified row stride
670   Index m_col_strides; // User specified col stride
671
672   Index m_in_row_strides; // User specified input row stride
673   Index m_in_col_strides; // User specified input col stride
674
675   Index m_rowPaddingTop;  // Row padding
676   Index m_colPaddingLeft; // Column padding
677
678   internal::TensorIntDivisor<Index> m_fastOutputRows;
679   internal::TensorIntDivisor<Index> m_fastDimZero;
680
681   const TensorEvaluator<ArgType, Device> m_impl;
682 };
683
684 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
685           typename Scalar, typename Index, typename nocontract_t, typename contract_t, int Side,
686           int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
687 class TensorContractionSubMapper<
688   Scalar, Index, Side,
689   TensorEvaluator<
690     const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
691   nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
692 {
693 public:
694   typedef typename packet_traits<Scalar>::type Packet;
695   typedef typename packet_traits<Scalar>::half HalfPacket;
696
697   typedef TensorContractionInputMapper<
698     Scalar, Index, Side,
699     TensorEvaluator<
700       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
701     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
702     ParentMapper;
703
704   typedef TensorContractionSubMapper<
705     Scalar, Index, Side,
706     TensorEvaluator<
707       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
708     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
709     Self;
710
711   typedef Self LinearMapper;
712
713   typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
714
715   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const ParentMapper &base_mapper,
716                                                                    Index vert_offset,
717                                                                    Index horiz_offset)
718     : m_depth_offset(vert_offset), m_col_offset(horiz_offset), m_base_mapper(base_mapper)
719   {
720     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
721   }
722   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const Self &base_mapper,
723                                                                    Index vert_offset,
724                                                                    Index horiz_offset)
725     : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
726       m_col_offset(horiz_offset + base_mapper.m_col_offset),
727       m_base_mapper(base_mapper.m_base_mapper)
728   {
729     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
730   }
731   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const
732   {
733     return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
734   }
735   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const
736   {
737     return m_base_mapper(i + m_depth_offset, j + m_col_offset);
738   }
739
740   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const
741   {
742     return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
743   }
744   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const
745   {
746     return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, j + m_col_offset);
747   }
748   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i) const
749   {
750     return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, m_colIndex,
751                                            m_otherIndex);
752   }
753
754   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const
755   {
756     return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
757   }
758   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i) const
759   {
760     typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
761     return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
762       i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
763   }
764   template <typename Packet> EIGEN_DEVICE_FUNC bool aligned(Index) const { return false; }
765
766   EIGEN_DEVICE_FUNC
767   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { return m_base_mapper.nonStandardPatches(); }
768
769   // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
770   // index respectively that fits into the peeled_k elements starting at
771   // m_depth_offset.
772
773   EIGEN_DEVICE_FUNC
774   EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const
775   {
776     const Index max_col =
777       (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / fastPatchColStride();
778     return std::min<Index>(1 + max_col, patchCols());
779   }
780
781   EIGEN_DEVICE_FUNC
782   EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, const Index col) const
783   {
784     const Index max_row =
785       (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - col * patchColStride()) /
786       fastPatchRowStride();
787     return std::min<Index>(1 + max_row, patchRows());
788   }
789
790   EIGEN_DEVICE_FUNC
791   EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, Index row) const
792   {
793     const Index max_depth = m_depth_offset + peeled_k - //
794                             col * patchColStride() -    //
795                             row * patchRowStride();
796     return std::min<Index>(max_depth, patchDepth());
797   }
798
799   // MaxDepth uses only the remaining number of elements in the peeled_k.
800   EIGEN_DEVICE_FUNC
801   EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, const Index start_depth) const
802   {
803     return std::min<Index>(start_depth + num_elements, patchDepth());
804   }
805
806   // Every register matters in this code, so sometimes to prevent register
807   // spilling, instead of the variable that you would expect to see, we use
808   // another one, that is guaranteed to have the same value. E.g. patch depth is
809   // always the same as input depth, and it's also the same as input row stride.
810   // Bunch of other parameters have similar relations.
811
812   typedef internal::TensorIntDivisor<Index> IndexDivisor;
813
814   EIGEN_DEVICE_FUNC
815   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_base_mapper.m_rowInputStride; }
816   EIGEN_DEVICE_FUNC
817   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_base_mapper.m_colStride; }
818   EIGEN_DEVICE_FUNC
819   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_base_mapper.m_patch_cols; }
820
821   EIGEN_DEVICE_FUNC
822   EIGEN_ALWAYS_INLINE Index patchRowStride() const
823   {
824     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
825                  "Patch depth must be equal to patch row stride.");
826     return patchDepth();
827   }
828   EIGEN_DEVICE_FUNC
829   EIGEN_ALWAYS_INLINE Index patchColStride() const { return m_base_mapper.m_patch_col_stride; }
830
831   EIGEN_DEVICE_FUNC
832   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const
833   {
834     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
835                  "Patch depth must be equal to patch row stride.");
836     return m_base_mapper.m_fastDimZero; // patch_depth
837   }
838   EIGEN_DEVICE_FUNC
839   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const
840   {
841     return m_base_mapper.m_fastPatchColStride;
842   }
843
844   EIGEN_DEVICE_FUNC
845   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const
846   {
847     const Index inputIndex = depth + baseIndex;
848     return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
849   }
850   EIGEN_DEVICE_FUNC
851   EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, const Index baseIndex) const
852   {
853     const Index inputIndex = depth + baseIndex;
854     return m_base_mapper.m_impl.coeff(inputIndex);
855   }
856   template <typename PacketT = Packet>
857   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
858     TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
859   partialPacketNoPadding(const Index depth, const Index baseIndex, Index num_coeffs) const
860   {
861     const Index inputIndex = depth + baseIndex;
862     return m_base_mapper.m_impl.template partialPacket<PacketT>(inputIndex,
863                                                                 mask<PacketT>(0, num_coeffs));
864   }
865   EIGEN_DEVICE_FUNC
866   EIGEN_ALWAYS_INLINE bool hasPadding() const
867   {
868     // TODO(ezhulenev): It does seems that for inflated filter it's still
869     // possible to guarantee "no padding or skipping" for non-standard packing.
870     if (nonStandardPatches())
871       return true;
872
873     // Non zero padding before.
874     if (m_base_mapper.m_rowPaddingTop > 0)
875       return true;
876     if (m_base_mapper.m_colPaddingLeft > 0)
877       return true;
878
879     // Non zero padding after in rows.
880     const Index last_row = (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
881     if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows)
882       return true;
883
884     // Non zero padding after in cols.
885     const Index last_col = (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
886     if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols)
887       return true;
888
889     return false;
890   }
891   EIGEN_DEVICE_FUNC
892   EIGEN_ALWAYS_INLINE bool padRow(const Index row) const
893   {
894     const Index r = m_rowIndex + row;
895     return r < 0 || r >= m_base_mapper.m_inputRows;
896   }
897   EIGEN_DEVICE_FUNC
898   EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, const Index last_row) const
899   {
900     return m_rowIndex + first_row < 0 || m_rowIndex + last_row >= m_base_mapper.m_inputRows;
901   }
902   EIGEN_DEVICE_FUNC
903   EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row, Index *orig_row) const
904   {
905     eigen_assert(nonStandardPatches());
906
907     const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
908     *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
909                   ? input_row
910                   : ((input_row >= 0) ? (input_row / m_base_mapper.m_fastInputRowStride) : 0);
911
912     return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
913            (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
914   }
915   EIGEN_DEVICE_FUNC
916   EIGEN_ALWAYS_INLINE bool padCol(const Index col) const
917   {
918     const Index c = m_colIndex + col;
919     return c < 0 || c >= m_base_mapper.m_inputCols;
920   }
921   EIGEN_DEVICE_FUNC
922   EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col, Index *orig_col) const
923   {
924     eigen_assert(nonStandardPatches());
925
926     const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
927     *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
928                   ? input_col
929                   : ((input_col >= 0) ? (input_col / m_base_mapper.m_fastInputColStride) : 0);
930
931     return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
932            (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
933   }
934   EIGEN_DEVICE_FUNC
935   EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const
936   {
937     const Index r = m_rowIndex + row;
938     const Index c = m_colIndex + col;
939     return r * m_base_mapper.m_rowInputStride + c * m_base_mapper.m_colInputStride + m_otherIndex;
940   }
941   // Compute a base index when original input row and column were precomputed
942   // using padOrSkipRow and padOrSkipCol. Used only for non standard patches.
943   EIGEN_DEVICE_FUNC
944   EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row, const Index orig_col) const
945   {
946     return orig_row * m_base_mapper.m_rowInputStride + orig_col * m_base_mapper.m_colInputStride +
947            m_otherIndex;
948   }
949
950   EIGEN_DEVICE_FUNC
951   EIGEN_ALWAYS_INLINE Index rowStride() const { return m_base_mapper.m_row_strides; }
952   EIGEN_DEVICE_FUNC
953   EIGEN_ALWAYS_INLINE Index colStride() const { return m_base_mapper.m_col_strides; }
954
955   EIGEN_DEVICE_FUNC
956   EIGEN_ALWAYS_INLINE Index rowOffset() const
957   {
958     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
959     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
960     return patchOffset - colOffset * m_base_mapper.m_colStride;
961   }
962
963   EIGEN_DEVICE_FUNC
964   EIGEN_ALWAYS_INLINE Index colOffset() const
965   {
966     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
967     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
968     return colOffset;
969   }
970
971   EIGEN_DEVICE_FUNC
972   EIGEN_ALWAYS_INLINE Index depthOffset() const { return m_depth_offset % patchDepth(); }
973
974   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const
975   {
976     return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
977   }
978
979 private:
980   Index m_depth_offset; // First row in the input matrix
981   Index m_col_offset;   // First col in the input matrix
982
983   // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
984   // indices for the first element in a patch specified by col_offset
985   // (see computeBaseIndices(...) for details).
986   Index m_rowIndex;
987   Index m_colIndex;
988   Index m_otherIndex;
989
990   const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
991                                     // performs better in benchmarks.
992 };
993
994 // Arrange a block of the right input matrix (in our case it's always a "virtual
995 // matrix" constructed from extracted image patches) in contiguous memory.
996 //
997 // Given column major input (A0 beside A1 in memory):
998 // A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
999 // A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
1000 // A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
1001 // A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
1002 // A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
1003 // A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
1004 // A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
1005 // A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
1006 // A8 ...
1007 // ...
1008 //
1009 // *) A, B, C, ... - patches extracted from the original input.
1010 // *) A0, A1, A2 ... - values from the same patch at different offsets.
1011 //
1012 // The traversal (packed rhs memory) order (B0 besides A0 in memory):
1013 // A0 B0 C0 D0 A1 B1 C1 D1 ...
1014 // E0 F0 G0 H0 E1 F1 G1 H1 ...
1015 // ...
1016 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1017 //
1018 // This traversal order must be the same as in default gemm_pack_rhs defined in
1019 // GeneralBlockPanelKernel.h.
1020 //
1021 // *) nr - number of registers along the 'n' dimension.
1022 //    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1023 //    Multiplication" paper.
1024 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1025           typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1026           int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1027           int nr>
1028 struct gemm_pack_rhs<
1029   Scalar, Index,
1030   TensorContractionSubMapper<
1031     Scalar, Index, Rhs,
1032     TensorEvaluator<
1033       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1034     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1035   nr, ColMajor, false, false>
1036 {
1037   typedef TensorContractionSubMapper<
1038     Scalar, Index, Rhs,
1039     TensorEvaluator<
1040       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1041     nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
1042     SubMapper;
1043   typedef SubMapper DataMapper;
1044   typedef typename packet_traits<Scalar>::type Packet;
1045
1046   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1047
1048   EIGEN_DEVICE_FUNC
1049   EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1050                                     Index stride = 0, Index offset = 0) const
1051   {
1052     eigen_assert(stride == 0);
1053     eigen_assert(offset == 0);
1054     (void)stride;
1055     (void)offset;
1056
1057     const Index packet_cols4 = (cols / 4) * 4;
1058     const Index peeled_k = (depth / packet_size) * packet_size;
1059     const bool non_standard_patches = rhs.nonStandardPatches();
1060
1061     for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1062     {
1063       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1064       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1065       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1066       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1067
1068       Index k = 0;
1069       if ((packet_size % 4) == 0 && !non_standard_patches)
1070       {
1071         // FAST PATH:
1072         // Iterate over patch columns and rows, if we know that a single
1073         // packet do not span across multiple rows or columns.
1074         if ((rhs.patchDepth() % packet_size) == 0)
1075         {
1076           const Index start_col = rhs.colOffset();
1077           const Index max_col = rhs.maxCol(peeled_k);
1078
1079           for (Index c = start_col; c < max_col; ++c)
1080           {
1081             eigen_assert(k <= peeled_k);
1082
1083             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1084             const Index max_row = rhs.maxRow(peeled_k, c);
1085
1086             const bool pad_col0 = dm0.padCol(c);
1087             const bool pad_col1 = dm1.padCol(c);
1088             const bool pad_col2 = dm2.padCol(c);
1089             const bool pad_col3 = dm3.padCol(c);
1090
1091             // Check if we can squeeze reads along the `row` and `depth`
1092             // dimensions (two innermost dimensions).
1093             if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&   //
1094                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1095                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1096                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1097                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1098             {
1099               // Compute how many elements we can squeeze read.
1100               const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1101
1102               // Upper bound for the number of elements in the depth dimension
1103               // that we can squeeze read.
1104               const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1105
1106               // Do not overshoot beyond the block size.
1107               const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1108               eigen_assert((max_depth - start_depth) % packet_size == 0);
1109
1110               const Index idx0 = dm0.baseIndex(start_row, c);
1111               const Index idx1 = dm1.baseIndex(start_row, c);
1112               const Index idx2 = dm2.baseIndex(start_row, c);
1113               const Index idx3 = dm3.baseIndex(start_row, c);
1114
1115               for (Index d = start_depth; d < max_depth; d += packet_size)
1116               {
1117                 eigen_assert(k < peeled_k);
1118                 PacketBlock<Packet, 4> kernel;
1119                 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1120                 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1121                 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1122                 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1123                 ptranspose(kernel);
1124                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1125                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1126                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1127                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1128                 block += 4 * packet_size;
1129                 k += packet_size;
1130               }
1131
1132               // Go to the next column.
1133               continue;
1134             }
1135
1136             // If we can't squeeze reads, process rows one by one.
1137             for (Index r = start_row; r < max_row; ++r)
1138             {
1139               eigen_assert(k <= peeled_k);
1140
1141               const bool pad0 = pad_col0 || dm0.padRow(r);
1142               const bool pad1 = pad_col1 || dm1.padRow(r);
1143               const bool pad2 = pad_col2 || dm2.padRow(r);
1144               const bool pad3 = pad_col3 || dm3.padRow(r);
1145
1146               const Index idx0 = dm0.baseIndex(r, c);
1147               const Index idx1 = dm1.baseIndex(r, c);
1148               const Index idx2 = dm2.baseIndex(r, c);
1149               const Index idx3 = dm3.baseIndex(r, c);
1150
1151               const Index start_depth =
1152                 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1153               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1154               eigen_assert((max_depth - start_depth) % packet_size == 0);
1155
1156               for (Index d = start_depth; d < max_depth; d += packet_size)
1157               {
1158                 eigen_assert(k < peeled_k);
1159                 PacketBlock<Packet, 4> kernel;
1160                 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1161                 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1162                 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1163                 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1164                 ptranspose(kernel);
1165                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1166                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1167                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1168                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1169                 block += 4 * packet_size;
1170                 k += packet_size;
1171               }
1172             }
1173           }
1174
1175           // The loop above should fill peeled_k elements.
1176           eigen_assert(peeled_k == k);
1177         }
1178         else
1179         {
1180           for (; k < peeled_k; k += packet_size)
1181           {
1182             PacketBlock<Packet, 4> kernel;
1183             kernel.packet[0] = dm0.loadPacketStandard(k);
1184             kernel.packet[1] = dm1.loadPacketStandard(k);
1185             kernel.packet[2] = dm2.loadPacketStandard(k);
1186             kernel.packet[3] = dm3.loadPacketStandard(k);
1187             ptranspose(kernel);
1188             pstoreu(block + 0 * packet_size, kernel.packet[0]);
1189             pstoreu(block + 1 * packet_size, kernel.packet[1]);
1190             pstoreu(block + 2 * packet_size, kernel.packet[2]);
1191             pstoreu(block + 3 * packet_size, kernel.packet[3]);
1192             block += 4 * packet_size;
1193           }
1194         }
1195       }
1196
1197       // Copy the remaining coefficients of the column block after the peeled_k.
1198       if (!rhs.nonStandardPatches())
1199       {
1200         for (; k < depth; k++)
1201         {
1202           block[0] = dm0.loadCoeffStandard(k);
1203           block[1] = dm1.loadCoeffStandard(k);
1204           block[2] = dm2.loadCoeffStandard(k);
1205           block[3] = dm3.loadCoeffStandard(k);
1206           block += 4;
1207         }
1208       }
1209       else
1210       {
1211         for (; k < depth; k++)
1212         {
1213           block[0] = dm0(k);
1214           block[1] = dm1(k);
1215           block[2] = dm2(k);
1216           block[3] = dm3(k);
1217           block += 4;
1218         }
1219       }
1220     }
1221
1222     // copy the remaining columns one at a time (nr==1)
1223     for (Index j2 = packet_cols4; j2 < cols; ++j2)
1224     {
1225       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1226       for (Index k = 0; k < depth; k++)
1227       {
1228         *block = dm0(k);
1229         block += 1;
1230       }
1231     }
1232   }
1233 };
1234
1235 // Template specialization for packet_size = 2. We must special-case packet
1236 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1237 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1238           typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1239           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1240 struct gemm_pack_rhs<
1241   Scalar, Index,
1242   TensorContractionSubMapper<
1243     Scalar, Index, Rhs,
1244     TensorEvaluator<
1245       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1246     nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1247   nr, ColMajor, false, false>
1248 {
1249   typedef TensorContractionSubMapper<
1250     Scalar, Index, Rhs,
1251     TensorEvaluator<
1252       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1253     nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>
1254     SubMapper;
1255   typedef SubMapper DataMapper;
1256   typedef typename packet_traits<Scalar>::type Packet;
1257
1258   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1259
1260   EIGEN_DEVICE_FUNC
1261   EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1262                                     Index stride = 0, Index offset = 0) const
1263   {
1264     eigen_assert(stride == 0);
1265     eigen_assert(offset == 0);
1266
1267     (void)stride;
1268     (void)offset;
1269
1270     const int packet_size = 2;
1271     const Index packet_cols4 = (cols / 4) * 4;
1272     const Index peeled_k = (depth / packet_size) * packet_size;
1273     const bool non_standard_patches = rhs.nonStandardPatches();
1274
1275     for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1276     {
1277       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1278       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1279       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1280       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1281
1282       Index k = 0;
1283       if (!non_standard_patches)
1284       {
1285         // FAST PATH:
1286         // Iterate over patch columns and rows if we know that a single
1287         // packet do not span across multiple rows or columns.
1288         if ((rhs.patchDepth() % packet_size) == 0)
1289         {
1290           const Index start_col = rhs.colOffset();
1291           const Index max_col = rhs.maxCol(peeled_k);
1292
1293           for (Index c = start_col; c < max_col; ++c)
1294           {
1295             eigen_assert(k <= peeled_k);
1296
1297             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1298             const Index max_row = rhs.maxRow(peeled_k, c);
1299
1300             const bool pad_col0 = dm0.padCol(c);
1301             const bool pad_col1 = dm1.padCol(c);
1302             const bool pad_col2 = dm2.padCol(c);
1303             const bool pad_col3 = dm3.padCol(c);
1304
1305             // We can squeeze reads along the `row` and `depth` dimensions if
1306             // the row stride is `1`, which means that `row` and `depth`
1307             // dimensions are contiguous (two innermost dimensions).
1308             if (rhs.rowStride() == 1 &&                               //
1309                 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&   //
1310                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1311                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1312                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1313                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1314             {
1315               // Compute how many elements we can squeeze read.
1316               const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1317
1318               // Upper bound for the number of elements in the depth dimension
1319               // that we can squeeze read.
1320               const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1321
1322               // Do not overshoot beyond the block size.
1323               const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1324               eigen_assert((max_depth - start_depth) % packet_size == 0);
1325
1326               const Index idx0 = dm0.baseIndex(start_row, c);
1327               const Index idx1 = dm1.baseIndex(start_row, c);
1328               const Index idx2 = dm2.baseIndex(start_row, c);
1329               const Index idx3 = dm3.baseIndex(start_row, c);
1330
1331               for (Index d = start_depth; d < max_depth; d += packet_size)
1332               {
1333                 PacketBlock<Packet, 2> kernel0;
1334                 PacketBlock<Packet, 2> kernel1;
1335                 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1336                 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1337                 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1338                 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1339                 ptranspose(kernel0);
1340                 ptranspose(kernel1);
1341                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1342                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1343                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1344                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1345                 block += 4 * packet_size;
1346                 k += packet_size;
1347               }
1348
1349               // Go to the next column.
1350               continue;
1351             }
1352
1353             // If we can't squeeze reads, process rows one by one.
1354             for (Index r = start_row; r < max_row; ++r)
1355             {
1356               eigen_assert(k <= peeled_k);
1357
1358               const bool pad0 = pad_col0 || dm0.padRow(r);
1359               const bool pad1 = pad_col1 || dm1.padRow(r);
1360               const bool pad2 = pad_col2 || dm2.padRow(r);
1361               const bool pad3 = pad_col3 || dm3.padRow(r);
1362
1363               const Index idx0 = dm0.baseIndex(r, c);
1364               const Index idx1 = dm1.baseIndex(r, c);
1365               const Index idx2 = dm2.baseIndex(r, c);
1366               const Index idx3 = dm3.baseIndex(r, c);
1367
1368               const Index start_depth =
1369                 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1370               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1371               eigen_assert((max_depth - start_depth) % packet_size == 0);
1372
1373               for (Index d = start_depth; d < max_depth; d += packet_size)
1374               {
1375                 eigen_assert(k < peeled_k);
1376                 PacketBlock<Packet, 2> kernel0;
1377                 PacketBlock<Packet, 2> kernel1;
1378                 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1379                 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1380                 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1381                 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1382                 ptranspose(kernel0);
1383                 ptranspose(kernel1);
1384                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1385                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1386                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1387                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1388                 block += 4 * packet_size;
1389                 k += packet_size;
1390               }
1391             }
1392           }
1393
1394           // The loop above should fill peeled_k elements.
1395           eigen_assert(peeled_k == k);
1396         }
1397         else
1398         {
1399           // Packet can span multiple rows or columns, so we have to go
1400           // though the slower "standard" path.
1401           for (; k < peeled_k; k += packet_size)
1402           {
1403             PacketBlock<Packet, 2> kernel0;
1404             PacketBlock<Packet, 2> kernel1;
1405             kernel0.packet[0] = dm0.loadPacketStandard(k);
1406             kernel0.packet[1] = dm1.loadPacketStandard(k);
1407             kernel1.packet[0] = dm2.loadPacketStandard(k);
1408             kernel1.packet[1] = dm3.loadPacketStandard(k);
1409             ptranspose(kernel0);
1410             ptranspose(kernel1);
1411             pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1412             pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1413             pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1414             pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1415             block += 4 * packet_size;
1416           }
1417         }
1418       }
1419
1420       // Copy the remaining coefficients of the column block after the peeled_k.
1421       if (!non_standard_patches)
1422       {
1423         for (; k < depth; k++)
1424         {
1425           block[0] = dm0.loadCoeffStandard(k);
1426           block[1] = dm1.loadCoeffStandard(k);
1427           block[2] = dm2.loadCoeffStandard(k);
1428           block[3] = dm3.loadCoeffStandard(k);
1429           block += 4;
1430         }
1431       }
1432       else
1433       {
1434         for (; k < depth; k++)
1435         {
1436           block[0] = dm0(k);
1437           block[1] = dm1(k);
1438           block[2] = dm2(k);
1439           block[3] = dm3(k);
1440           block += 4;
1441         }
1442       }
1443     }
1444
1445     // Copy the remaining columns one at a time (nr==1).
1446     for (Index j2 = packet_cols4; j2 < cols; ++j2)
1447     {
1448       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1449       for (Index k = 0; k < depth; k++)
1450       {
1451         *block = dm0(k);
1452         block += 1;
1453       }
1454     }
1455   }
1456 };
1457
1458 // Special case for non-vectorized types such as float16.
1459 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1460           typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1461           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1462 struct gemm_pack_rhs<
1463   Scalar, Index,
1464   TensorContractionSubMapper<
1465     Scalar, Index, Rhs,
1466     TensorEvaluator<
1467       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1468     nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1469   nr, ColMajor, false, false>
1470 {
1471   typedef TensorContractionSubMapper<
1472     Scalar, Index, Rhs,
1473     TensorEvaluator<
1474       const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1475     nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
1476     SubMapper;
1477   typedef SubMapper DataMapper;
1478
1479   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1480
1481   EIGEN_DEVICE_FUNC
1482   EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1483                                     Index stride = 0, Index offset = 0) const
1484   {
1485     eigen_assert(stride == 0);
1486     eigen_assert(offset == 0);
1487
1488     (void)offset;
1489     (void)stride;
1490
1491     const Index packet_cols4 = (cols / 4) * 4;
1492
1493     for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1494     {
1495       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1496       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1497       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1498       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1499
1500       if (!rhs.nonStandardPatches())
1501       {
1502         for (Index k = 0; k < depth; k++)
1503         {
1504           block[0] = dm0.loadCoeffStandard(k);
1505           block[1] = dm1.loadCoeffStandard(k);
1506           block[2] = dm2.loadCoeffStandard(k);
1507           block[3] = dm3.loadCoeffStandard(k);
1508           block += 4;
1509         }
1510       }
1511       else
1512       {
1513         for (Index k = 0; k < depth; k++)
1514         {
1515           block[0] = dm0(k);
1516           block[1] = dm1(k);
1517           block[2] = dm2(k);
1518           block[3] = dm3(k);
1519           block += 4;
1520         }
1521       }
1522     }
1523
1524     // Copy the remaining columns one at a time (nr==1).
1525     for (Index j2 = packet_cols4; j2 < cols; ++j2)
1526     {
1527       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1528       for (Index k = 0; k < depth; k++)
1529       {
1530         *block = dm0(k);
1531         block += 1;
1532       }
1533     }
1534   }
1535 };
1536 } // end namespace internal
1537
1538 /** SpatialConvolution
1539  * \ingroup CXX11_NeuralNetworks_Module
1540  *
1541  * \brief Applies a 2D convolution over a multichannel input image.
1542  *
1543  * The input parameter is expected to be a tensor with a rank of 3 or more
1544  * (channels, height, width, and optionally others)
1545  * The kernel parameter is expected to be a 4D tensor (filters, channels,
1546  * kernel_height, kernel_width)
1547  * The input and the kernel must both be in col-major layout. The result will
1548  * also be in col-major layout.
1549  *
1550  * If col_in_stride, row_in_stride > 1, then applies convolution with holes
1551  * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
1552  * pixels.
1553  *
1554  * If padding_top, padding_bottom, padding_left, or padding_right is specified,
1555  * then those paddings will be used to pad the input, and padding_type must be
1556  * PADDING_VALID.
1557  *
1558  * The result can be assigned to a tensor of rank equal to the rank of the
1559  * input. The dimensions of the result will be filters, height, width (and
1560  * others if applicable).
1561  *
1562  * It is possible to swap the order of the width and height dimensions provided
1563  * that the same order is used in the input, the kernel, and the output.
1564  *
1565  * It is also possible to add an output kernel to the contraction, output
1566  * kernel is called by Eigen when it "finalizes" the block of an output tensor.
1567  *
1568  */
1569 template <typename Input, typename Kernel, typename OutputKernel = const NoOpOutputKernel>
1570 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename internal::conditional<
1571   internal::traits<Input>::Layout == ColMajor,
1572   TensorReshapingOp<
1573     const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>,
1574     const TensorContractionOp<
1575       const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1576       const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1577                               const Kernel>,
1578       const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1579                               const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1580       const OutputKernel>>,
1581   TensorReshapingOp<
1582     const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>,
1583     const TensorContractionOp<
1584       const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1585       const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1586                               const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1587       const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1588                               const Kernel>,
1589       const OutputKernel>>>::type
1590 SpatialConvolution(const Input &input, const Kernel &kernel, const Index row_stride = 1,
1591                    const Index col_stride = 1, const PaddingType padding_type = PADDING_SAME,
1592                    const Index row_in_stride = 1, const Index col_in_stride = 1,
1593                    const OutputKernel &output_kernel = OutputKernel(), Index padding_top = 0,
1594                    Index padding_bottom = 0, Index padding_left = 0, Index padding_right = 0)
1595 {
1596   typedef typename internal::traits<Input>::Index TensorIndex;
1597   TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions,
1598                    internal::traits<Input>::Layout, TensorIndex>>
1599     in(input);
1600   TensorRef<
1601     Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions,
1602            internal::traits<Kernel>::Layout, TensorIndex>>
1603     kern(kernel);
1604
1605   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1606                       YOU_MADE_A_PROGRAMMING_MISTAKE)
1607   const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1608
1609   const int NumDims = internal::traits<Input>::NumDimensions;
1610
1611   // Number of filters to apply. This is the same as the output depth of the
1612   // result
1613   const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
1614   // Number of channels. This is the same as the input depth.
1615   const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
1616   const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
1617   const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
1618
1619   const Index kernelRowsEff = kernelRows + (kernelRows - 1) * (row_in_stride - 1);
1620   const Index kernelColsEff = kernelCols + (kernelCols - 1) * (col_in_stride - 1);
1621
1622   array<IndexPair<TensorIndex>, 1> contract_dims;
1623   contract_dims[0] = IndexPair<TensorIndex>(1, 0);
1624
1625   const TensorIndex InputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1626   const TensorIndex InputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1627   const bool padding_explicit = (padding_top || padding_bottom || padding_left || padding_right);
1628
1629   TensorIndex out_height;
1630   TensorIndex out_width;
1631   switch (padding_type)
1632   {
1633     case PADDING_VALID:
1634     {
1635       const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
1636       const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
1637       out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
1638       out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
1639       break;
1640     }
1641     case PADDING_SAME:
1642     {
1643       eigen_assert(!padding_explicit);
1644       out_height = divup(InputRows, row_stride);
1645       out_width = divup(InputCols, col_stride);
1646       break;
1647     }
1648     default:
1649     {
1650       // Initialize unused variables to avoid a compiler warning
1651       out_height = 0;
1652       out_width = 0;
1653       eigen_assert(false && "unexpected padding");
1654     }
1655   }
1656
1657   // Molds the output of the patch extraction code into a 2d tensor:
1658   // - the first dimension (dims[0]): the patch values to be multiplied with the
1659   // kernels
1660   // - the second dimension (dims[1]): everything else
1661   DSizes<TensorIndex, 2> pre_contract_dims;
1662   if (isColMajor)
1663   {
1664     pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
1665     pre_contract_dims[1] = out_height * out_width;
1666     for (int i = 3; i < NumDims; ++i)
1667     {
1668       pre_contract_dims[1] *= in.dimension(i);
1669     }
1670   }
1671   else
1672   {
1673     pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
1674     pre_contract_dims[0] = out_height * out_width;
1675     for (int i = 0; i < NumDims - 3; ++i)
1676     {
1677       pre_contract_dims[0] *= in.dimension(i);
1678     }
1679   }
1680
1681   // Molds the output of the contraction into the shape expected by the used
1682   // (assuming this is ColMajor):
1683   // - 1st dim: kernel filters
1684   // - 2nd dim: output height
1685   // - 3rd dim: output width
1686   // - 4th dim and beyond: everything else including batch size
1687   DSizes<TensorIndex, NumDims> post_contract_dims;
1688   if (isColMajor)
1689   {
1690     post_contract_dims[0] = kernelFilters;
1691     post_contract_dims[1] = out_height;
1692     post_contract_dims[2] = out_width;
1693     for (int i = 3; i < NumDims; ++i)
1694     {
1695       post_contract_dims[i] = in.dimension(i);
1696     }
1697   }
1698   else
1699   {
1700     post_contract_dims[NumDims - 1] = kernelFilters;
1701     post_contract_dims[NumDims - 2] = out_height;
1702     post_contract_dims[NumDims - 3] = out_width;
1703     for (int i = 0; i < NumDims - 3; ++i)
1704     {
1705       post_contract_dims[i] = in.dimension(i);
1706     }
1707   }
1708
1709   DSizes<TensorIndex, 2> kernel_dims;
1710   if (isColMajor)
1711   {
1712     kernel_dims[0] = kernelFilters;
1713     kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
1714   }
1715   else
1716   {
1717     kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
1718     kernel_dims[1] = kernelFilters;
1719   }
1720   if (padding_explicit)
1721   {
1722     return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
1723                   kernel.reshape(kernel_dims)
1724                     .contract(input
1725                                 .extract_image_patches(kernelRows, kernelCols, row_stride,
1726                                                        col_stride, row_in_stride, col_in_stride,
1727                                                        /*row_inflate_stride=*/1,
1728                                                        /*col_inflate_stride=*/1, padding_top,
1729                                                        padding_bottom, padding_left, padding_right,
1730                                                        /*padding_value=*/0)
1731                                 .reshape(pre_contract_dims),
1732                               contract_dims, output_kernel)
1733                     .reshape(post_contract_dims),
1734                   input
1735                     .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1736                                            row_in_stride, col_in_stride,
1737                                            /*row_inflate_stride=*/1,
1738                                            /*col_inflate_stride=*/1, padding_top, padding_bottom,
1739                                            padding_left, padding_right,
1740                                            /*padding_value=*/0)
1741                     .reshape(pre_contract_dims)
1742                     .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1743                     .reshape(post_contract_dims));
1744   }
1745   else
1746   {
1747     return choose(
1748       Cond<internal::traits<Input>::Layout == ColMajor>(),
1749       kernel.reshape(kernel_dims)
1750         .contract(input
1751                     .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1752                                            row_in_stride, col_in_stride, padding_type)
1753                     .reshape(pre_contract_dims),
1754                   contract_dims, output_kernel)
1755         .reshape(post_contract_dims),
1756       input
1757         .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
1758                                col_in_stride, padding_type)
1759         .reshape(pre_contract_dims)
1760         .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1761         .reshape(post_contract_dims));
1762   }
1763 }
1764
1765 } // end namespace Eigen
1766
1767 #endif // __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__