2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
19 #define __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
21 #include "cker/eigen/eigen_convolution_helpers.h"
23 // Note this header is used in both TF and TFLite.
30 // WARNING: Most of the code here implicitly assumes that the matrix is in
31 // ColMajor layout. This is guaranteed by the tensor contraction (see
32 // TensorContraction.h).
34 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
35 // We don't want to actually extract image patches and reshape the result into
36 // a matrix (this involves allocating huge extra memory), so the patch
37 // extraction and reshape operations are implicit.
39 // TensorContractionInputMapper takes a matrix index and returns the coefficient
40 // (or the packet) of the "virtual tensor", that would be at that index if we
41 // were to actually reshape the result of patch extraction.
43 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
44 // at the given vertical and horizontal offsets.
46 // "Virtual matrix" dimensions:
47 // *0: kernelChannels * kernelRows * kernelCols;
// 1: out_height * out_width; * OTHERS (e.g. batches, etc...)
50 // *) extracted patches are continuous in memory (innermost dimension assuming
// With these dimensions:
54 // row - offset within a single patch (in code: patchId)
55 // col - index of the extracted patch (in code: patchIndex)
56 // patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
58 // TODO(ezhulenev): Consolidate this part of the code with the image patch
59 // extraction code since they are both very similar.
// Specialization of TensorContractionInputMapper for the "virtual matrix"
// produced by extracting image patches and reshaping them for a contraction
// (spatial convolution). It maps (row, col) = (offset within a patch, patch
// index) onto coefficients of the underlying input tensor without ever
// materializing the patches.
// NOTE(review): several template-argument lines of this specialization (the
// Scalar/Index arguments and the enclosing TensorEvaluator<...> of the
// reshaping expression) appear to be missing from this excerpt -- confirm
// against the full file.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
          typename Scalar_, typename Index, typename nocontract_t, typename contract_t, int Side,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
  const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
  nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
  // Coefficient scalar type.
  typedef Scalar_ Scalar;
  // Self type alias (used for the copy constructor and friends).
  typedef TensorContractionInputMapper<
    const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
  // Companion sub-mapper type providing offset views into this mapper.
  typedef TensorContractionSubMapper<
    const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  // SIMD packet type for Scalar.
  typedef typename packet_traits<Scalar>::type Packet;
  // Evaluator of the underlying (un-patched) input tensor.
  typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
// Builds the mapper from the evaluator of the reshaped image-patch
// expression: caches patch/output geometry, user strides, paddings, and
// precomputes fast integer divisors used on the hot load paths.
// NOTE(review): the `tensor` parameter name line, the local declarations of
// patch_depth/patch_rows, and the if/else brace lines appear to be missing
// from this excerpt -- confirm against the full file.
TensorContractionInputMapper(
  const TensorEvaluator<
    const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>
  const nocontract_t &, const nocontract_t &, const contract_t &, const contract_t &)
  : m_impl(tensor.impl().impl())
  // The image-patch result is [depth, rows, cols, patches, ...] in ColMajor;
  // the order is reversed for RowMajor.
  if (internal::traits<ArgType>::Layout == ColMajor)
    patch_depth = tensor.impl().dimensions()[0];
    patch_rows = tensor.impl().dimensions()[1];
    m_patch_cols = tensor.impl().dimensions()[2];
    m_num_patches = tensor.impl().dimensions()[3];
    const size_t NumDims = tensor.impl().dimensions().size();
    patch_depth = tensor.impl().dimensions()[NumDims - 1];
    patch_rows = tensor.impl().dimensions()[NumDims - 2];
    m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
    m_num_patches = tensor.impl().dimensions()[NumDims - 4];

  // Strides for navigating through the single patch.
  m_patch_row_stride = patch_depth;
  m_patch_col_stride = patch_rows * m_patch_row_stride;

  // Inflation strides of the patch op (non-1 when the input is upsampled).
  m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
  m_patch_col_inflate_strides = tensor.impl().colInflateStride();

  m_colStride = patch_rows;

  // Convolution output geometry and user-specified strides.
  m_outputRows = tensor.impl().outputRows();
  m_outputCols = tensor.impl().outputCols();
  m_row_strides = tensor.impl().userRowStride();
  m_col_strides = tensor.impl().userColStride();

  m_in_row_strides = tensor.impl().userInRowStride();
  m_in_col_strides = tensor.impl().userInColStride();

  // Dimensions of the original (pre-patch) input tensor.
  if (internal::traits<ArgType>::Layout == ColMajor)
    m_inputRows = tensor.impl().impl().dimensions()[1];
    m_inputCols = tensor.impl().impl().dimensions()[2];
    const int NumDims = tensor.impl().impl().dimensions().size();
    m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
    m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];

  // Element strides inside the input tensor (depth-major layout).
  m_rowInputStride = patch_depth;
  m_colInputStride = patch_depth * m_inputRows;
  m_patchInputStride = patch_depth * m_inputRows * m_inputCols;

  m_rowPaddingTop = tensor.impl().rowPaddingTop();
  m_colPaddingLeft = tensor.impl().colPaddingLeft();

  // Fast integer divisors replacing the hot division/modulo operations.
  m_fastPatchRowStride = internal::TensorIntDivisor<Index>(m_patch_row_stride);
  m_fastPatchColStride = internal::TensorIntDivisor<Index>(m_patch_col_stride);
  m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
  m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
  m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
  m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
  m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
  m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
// Copy constructor: the mapper is copied into each contraction worker, so
// this performs a plain field-by-field copy of all cached geometry, strides
// and fast divisors.
TensorContractionInputMapper(const TensorContractionInputMapper &base_mapper)
  : m_impl(base_mapper.m_impl)
  m_patch_cols = base_mapper.m_patch_cols;
  m_num_patches = base_mapper.m_num_patches;

  m_patch_row_stride = base_mapper.m_patch_row_stride;
  m_patch_col_stride = base_mapper.m_patch_col_stride;

  m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
  m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

  m_colStride = base_mapper.m_colStride;

  m_rowInputStride = base_mapper.m_rowInputStride;
  m_colInputStride = base_mapper.m_colInputStride;
  m_patchInputStride = base_mapper.m_patchInputStride;

  m_inputRows = base_mapper.m_inputRows;
  m_inputCols = base_mapper.m_inputCols;

  m_outputRows = base_mapper.m_outputRows;
  m_outputCols = base_mapper.m_outputCols;
  m_row_strides = base_mapper.m_row_strides;
  m_col_strides = base_mapper.m_col_strides;

  m_in_row_strides = base_mapper.m_in_row_strides;
  m_in_col_strides = base_mapper.m_in_col_strides;

  m_rowPaddingTop = base_mapper.m_rowPaddingTop;
  m_colPaddingLeft = base_mapper.m_colPaddingLeft;

  m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
  m_fastPatchColStride = base_mapper.m_fastPatchColStride;
  m_fastInputRowStride = base_mapper.m_fastInputRowStride;
  m_fastInputColStride = base_mapper.m_fastInputColStride;
  m_fastNumPatches = base_mapper.m_fastNumPatches;
  m_fastColStride = base_mapper.m_fastColStride;
  m_fastOutputRows = base_mapper.m_fastOutputRows;
  m_fastDimZero = base_mapper.m_fastDimZero;
208 // If true, turns off some optimizations for loading packets since the image
209 // patches are "non-standard" such as there are non-trivial strides or
210 // inflations in the input.
212 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const
214 return m_in_row_strides != 1 || m_in_col_strides != 1 || m_patch_row_inflate_strides != 1 ||
215 m_patch_col_inflate_strides != 1;
// Returns a sub-mapper view anchored at (i, j) of the virtual matrix:
// i is the offset within a patch, j is the patch index.
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const
  return SubMapper(*this, i, j);
// Returns a linear-mapper view (same type as SubMapper) anchored at (i, j).
EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const
  return LinearMapper(*this, i, j);
231 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const
233 Index rowIndex, colIndex, otherIndex;
234 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
235 return loadCoeff(row, rowIndex, colIndex, otherIndex);
238 // Load the coefficient at the patchIndex location instead of the usual
240 // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
243 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const
245 Index rowIndex, colIndex, otherIndex;
246 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
247 return loadCoeff(row, rowIndex, colIndex, otherIndex);
251 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const
253 Index rowIndex, colIndex, otherIndex;
254 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
255 return loadPacket(row, rowIndex, colIndex, otherIndex);
258 // Load the packet at the patchIndex location instead of the usual m_rowIndex,
259 // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
261 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const
263 Index rowIndex, colIndex, otherIndex;
264 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
265 return loadPacket(row, rowIndex, colIndex, otherIndex);
// Evaluator of the underlying (un-reshaped, un-patched) input tensor.
EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device> &impl() const { return m_impl; }

// Patch depth; stored as the row input stride (the two values are always
// equal -- see the register-reuse note in the sub-mapper).
EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }

// Number of rows in a patch; stored as the within-patch column stride.
EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }

// Number of columns in a patch.
EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

// The sub-mapper needs access to the private load helpers and cached strides.
friend class TensorContractionSubMapper<
  const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
  nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>;
285 // Load coefficient from a patch specified by the "within patch offset"
286 // (patchId) and the precomputed indices of the first element of the patch.
288 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex,
289 Index otherIndex) const
291 // Find the offset of the element wrt the location of the first element.
292 const Index patchOffset = patchId / m_fastDimZero;
294 const Index colOffset = patchOffset / m_fastColStride;
295 const Index inputCol = colIndex + colOffset * m_in_col_strides;
296 const Index origInputCol = (m_patch_col_inflate_strides == 1)
298 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
300 const Index rowOffset = patchOffset - colOffset * m_colStride;
301 const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
302 const Index origInputRow = (m_patch_row_inflate_strides == 1)
304 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
305 if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
306 origInputRow >= m_inputRows || (inputCol != origInputCol * m_patch_col_inflate_strides) ||
307 (inputRow != origInputRow * m_patch_row_inflate_strides))
311 const Index depth = patchId - patchOffset * patchDepth();
312 const Index inputIndex =
313 depth + origInputRow * m_rowInputStride + origInputCol * m_colInputStride + otherIndex;
314 return m_impl.coeff(inputIndex);
317 // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
318 // and `in_strides` equal to 1 (template specialization without templates).
320 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex,
321 Index otherIndex) const
323 eigen_assert(!nonStandardPatches());
325 // Find the offset of the element wrt the location of the first element.
326 const Index patchOffset = patchId / m_fastDimZero;
327 const Index colOffset = patchOffset / m_fastColStride;
328 const Index rowOffset = patchOffset - colOffset * m_colStride;
329 const Index inputCol = colIndex + colOffset;
330 const Index inputRow = rowIndex + rowOffset;
331 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows)
335 const Index depth = patchId - patchOffset * patchDepth();
336 const Index inputIndex =
337 depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
338 return m_impl.coeff(inputIndex);
// Load packet from a patch specified by the "within patch offset"
// (patchId) and the precomputed indices of the first element of the patch.
// Dispatches to the slow element-wise path for non-standard patches, and to
// the standard (possibly partial-packet) path otherwise.
EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex,
                                      Index otherIndex) const
  const Index packetSize = internal::unpacket_traits<Packet>::size;
  EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
  eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

  // Non-trivial strides/inflation: assemble the packet coefficient by
  // coefficient.
  if (nonStandardPatches())
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  // Shadows the class-level typedef: decltype(m_impl) is the const-qualified
  // evaluator type, which is what loadPacketStandard's SFINAE checks against.
  typedef decltype(m_impl) TensorEvaluatorT;
  return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex, colIndex, otherIndex);
359 // Helper function to load a 'partial' packet - this is the single column
360 // part of a packet that is split across two columns. In the 'partial' packet,
361 // the elements corresponding to the column (specified through colOffset) are
362 // loaded and the rest of the elements are zero-filled into the 'partial'
363 // packet. This function is called from loadPacketStandardFromTwoColumns().
364 // This code path is exercised only when the packet type supports masked load
365 // and when the partial packet load is available in the TensorEvaluator.
367 EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(Index rowIndex, Index colIndex,
368 Index otherIndex, Index patchId,
370 const Index patchOffsets[],
371 Index colOffset) const
373 const Index inputCol = colIndex + colOffset;
374 const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
375 patchOffsets[1] - colOffset * m_colStride};
376 const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
378 if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || inputCol >= m_inputCols || inputCol < 0)
380 // Partial packet is all zeros
381 return internal::pset1<Packet>(Scalar(0));
383 else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
385 // From inputIndex-span[0], we need to load elements starting from index
386 // span[0] all the way upto (and including) span[1].
387 const Index depth = patchId - patchOffsets[0] * patchDepth();
388 const Index inputIndex =
389 depth + inputRows[0] * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
390 return m_impl.template partialPacket<Packet>(inputIndex - span[0],
391 mask<Packet>(span[0], span[1] + 1));
395 // Using slow path for this partial packet.
396 // We need to load elements starting from index span[0] all the way upto
397 // (and including) span[1]. We split this load into 3 parts:
398 // 0 : span[0]-1 - Zeros will be loaded for these indices
399 // span[0] : span[1] - Elements will be loaded here for these indices
400 // span[1]+1 : packetSize-1 - Zeross will be loaded for these indices
401 const Index packetSize = internal::unpacket_traits<Packet>::size;
403 typename internal::remove_const<Scalar>::type values[packetSize];
404 for (int i = 0; i < span[0]; ++i)
405 values[i] = Scalar(0);
406 for (int i = span[0]; i < span[1] + 1; ++i)
407 values[i] = loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
408 for (int i = span[1] + 1; i < packetSize; ++i)
409 values[i] = Scalar(0);
410 return internal::pload<Packet>(values);
// Helper function to load a packet that is split across two columns.
// If required, this function is called from loadPacketStandard() when the
// packet type supports masked load and when the partial packet load is
// available in the TensorEvaluator.
EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(Index patchId, Index rowIndex,
                                                            Index colIndex, Index otherIndex,
                                                            const Index patchOffsets[],
                                                            const Index colOffsets[]) const
  // Precondition: the two requested offsets lie in adjacent patch columns.
  eigen_assert(colOffsets[1] == colOffsets[0] + 1);
  const Index packetSize = internal::unpacket_traits<Packet>::size;

  // Packet to load will be split into 2 parts where each part spans a single
  // column. First determine where to split.
  const Index patchIdSplit = ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
  const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;

  // patchIds[i]: patchId corresponding to partial packet i
  // spans[i]: Start and end indices corresponding to the elements
  // to be loaded for partial packet i
  // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
  const Index patchIds[2] = {patchId, patchIdSplit + 1};
  const Index spans[2][2] = {{0, patchIdSplit - patchId},
                             {patchIdSplit - patchId + 1, packetSize - 1}};
  const Index patchOffsets2Cols[2][2] = {{patchOffsets[0], patchOffsetSplit},
                                         {patchOffsetSplit + 1, patchOffsets[1]}};

  // Load partial packets and do bit-wise OR to generate required packet
  // (each partial packet is zero where the other one carries data).
  return internal::por<Packet>(
    loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], spans[0],
                              patchOffsets2Cols[0], colOffsets[0]),
    loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], spans[1],
                              patchOffsets2Cols[1], colOffsets[1]));
// Helper function to load a packet that is present in a single column.
// If required, this function is called from loadPacketStandard().
EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(Index patchId, Index rowIndex,
                                                              Index colIndex, Index otherIndex,
                                                              const Index patchOffsets[],
                                                              const Index colOffsets[],
                                                              const Index inputCols[]) const
  // Precondition: both ends of the packet map to the same patch column.
  eigen_assert(colOffsets[0] == colOffsets[1]);
  const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
                               patchOffsets[1] - colOffsets[1] * m_colStride};
  eigen_assert(rowOffsets[0] <= rowOffsets[1]);
  const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};

  // The whole packet falls into the vertical padding region.
  if (inputRows[0] >= m_inputRows || inputRows[1] < 0)
    return internal::pset1<Packet>(Scalar(0)); // all zeros
  // Fully inside the input: one contiguous (unaligned) load suffices.
  if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
    const Index depth = patchId - patchOffsets[0] * patchDepth();
    const Index inputIndex =
      depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  // Partially padded: assemble the packet element by element.
  return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
// Load standard packet from a patch specified by the "within patch offset"
// (patchId) and the precomputed indices of the first element of the patch.
// This function will be called if partial packet loading is not available
// for the TensorEvaluator or if the packet type does not support masked
// load.
template <typename PacketT, typename TensorEvaluatorT>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
  !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
  const Index packetSize = internal::unpacket_traits<Packet>::size;
  EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
  eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

  eigen_assert(!nonStandardPatches());

  // When the packet size divides the patch depth a packet can never straddle
  // a row boundary, so the fast single-load path applies.
  if ((patchDepth() % packetSize) == 0)
    return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);

  // Offsets and input calculation here are identical to
  // loadCoeffStandard(...), but repeated twice.
  const Index patchOffsets[2] = {patchId / m_fastDimZero,
                                 (patchId + packetSize - 1) / m_fastDimZero};
  const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                               patchOffsets[1] / m_fastColStride};
  const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};

  // Entire packet lies in the horizontal padding region: all zeros.
  if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
    return internal::pset1<Packet>(Scalar(0));

  if (inputCols[0] == inputCols[1])
    return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
                                              patchOffsets, colOffsets, inputCols);
  // Packet straddles columns: fall back to element-by-element assembly.
  return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
524 // Load standard packet from a patch specified by the "within patch offset"
525 // (patchId) and the precomputed indices of the first element of the patch.
526 // This function will be called if partial packet loading is available for
527 // the TensorEvaluator and if the packet type supports masked load.
528 // The only difference between this and the other case is that if the packet
529 // to load is split across two columns, then in this case instead of going to
530 // the slow (element-by-element) load, we load two packets - each containing
531 // elements from one of the columns (rest of the elements of the packets are
532 // zeroes), and then combine these two packets to generate the required
533 // packet. The idea is to enable fast load (if possible) of these 'partial'
535 template <typename PacketT, typename TensorEvaluatorT>
536 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
537 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
538 loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
540 const Index packetSize = internal::unpacket_traits<PacketT>::size;
541 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
542 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
544 eigen_assert(!nonStandardPatches());
546 if ((patchDepth() % packetSize) == 0)
548 return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
551 // Offsets and input calculation here are identical to
552 // loadCoeffStandard(...), but repeated twice.
553 const Index patchOffsets[2] = {patchId / m_fastDimZero,
554 (patchId + packetSize - 1) / m_fastDimZero};
555 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
556 patchOffsets[1] / m_fastColStride};
557 const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
559 if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
562 return internal::pset1<PacketT>(Scalar(0));
564 if (inputCols[0] == inputCols[1])
566 return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
567 patchOffsets, colOffsets, inputCols);
569 if (inputCols[1] == inputCols[0] + 1)
571 return loadPacketStandardFromTwoColumns(patchId, rowIndex, colIndex, otherIndex, patchOffsets,
574 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
// Fast path: the packet size divides the patch depth, so the whole packet
// lies at a single (row, col) position of the patch. Either one contiguous
// unaligned load suffices, or the position is in the padding and the packet
// is all zeros.
EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, Index colIndex,
                                          Index otherIndex) const
  const Index packetSize = internal::unpacket_traits<Packet>::size;
  EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
  eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

  eigen_assert(!nonStandardPatches());
  eigen_assert((patchDepth() % packetSize) == 0);
  // Find the offset of the element wrt the location of the first element.
  const Index patchOffset = patchId / m_fastDimZero;
  // All elements of the packet must map to the same patch offset.
  eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

  const Index colOffset = patchOffset / m_fastColStride;
  const Index rowOffset = patchOffset - colOffset * m_colStride;
  const Index inputCol = colIndex + colOffset;
  const Index inputRow = rowIndex + rowOffset;
  // Position falls into the implicit-zero padding region.
  if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || inputRow >= m_inputRows)
    return internal::pset1<Packet>(Scalar(0));
  const Index depth = patchId - patchOffset * patchDepth();
  const Index inputIndex =
    depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
  return m_impl.template packet<Unaligned>(inputIndex);
607 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex,
609 Index otherIndex) const
611 const int packetSize = internal::unpacket_traits<Packet>::size;
613 typename internal::remove_const<Scalar>::type values[packetSize];
614 for (int i = 0; i < packetSize; ++i)
616 values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
618 Packet rslt = internal::pload<Packet>(values);
// Decompose a flattened patch index into the base row/col of that patch in
// the input tensor (adjusted for user strides and padding, so they can be
// negative) and the element offset of its batch/other dimensions.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
computeBaseIndices(Index patchIndex, Index &rowIndex, Index &colIndex, Index &otherIndex) const
  const size_t NumInputDims =
    array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  // With 3 input dims there is no batch dimension, so otherIndex is 0.
  otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
  const Index patch2DIndex =
    (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
  // Convert the batch index into an element offset in the input tensor.
  otherIndex *= m_patchInputStride;
  colIndex = patch2DIndex / m_fastOutputRows;
  rowIndex = patch2DIndex - colIndex * m_outputRows;
  // Map output coordinates to the input coordinates of the patch's first
  // element, accounting for user strides and padding.
  colIndex = colIndex * m_col_strides - m_colPaddingLeft;
  rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
Index m_patch_cols;  // number of columns in the patch
Index m_num_patches; // number of patches to extract.

// Strides for navigating through the single patch.
Index m_patch_row_stride;
Index m_patch_col_stride;
// Fast divisors for the two patch strides above.
internal::TensorIntDivisor<Index> m_fastPatchRowStride;
internal::TensorIntDivisor<Index> m_fastPatchColStride;

Index m_patch_row_inflate_strides; // the strides for row inflation in the image patch
Index m_patch_col_inflate_strides; // the strides for col inflation in the image patch
// Fast representation of inflation strides.
internal::TensorIntDivisor<Index> m_fastInputRowStride;
internal::TensorIntDivisor<Index> m_fastInputColStride;

internal::TensorIntDivisor<Index> m_fastNumPatches; // fast divisor by m_num_patches
internal::TensorIntDivisor<Index> m_fastColStride;  // fast divisor by m_colStride

Index m_rowInputStride;   // row stride in the input tensor
Index m_colInputStride;   // col stride in the input tensor
Index m_patchInputStride; // patch stride in the input tensor

Index m_inputRows; // Number of rows in the input tensor
Index m_inputCols; // Number of cols in the input tensor

Index m_outputRows; // Number of convolution output rows
Index m_outputCols; // Number of convolution output column

Index m_row_strides; // User specified row stride
Index m_col_strides; // User specified col stride

Index m_in_row_strides; // User specified input row stride
Index m_in_col_strides; // User specified input col stride

Index m_rowPaddingTop;  // Row padding
Index m_colPaddingLeft; // Column padding

internal::TensorIntDivisor<Index> m_fastOutputRows; // fast divisor by m_outputRows
internal::TensorIntDivisor<Index> m_fastDimZero;    // fast divisor by patch depth

const TensorEvaluator<ArgType, Device> m_impl; // evaluator of the input tensor
// Sub-mapper for the image-patch input mapper above: a lightweight view of
// the parent mapper shifted by fixed vertical (depth) and horizontal (patch)
// offsets. Provides the same load API plus helpers used by the packing code.
// NOTE(review): as with the mapper above, some template-argument lines and
// the typedef name lines (ParentMapper / Self) of this specialization appear
// to be missing from this excerpt -- confirm against the full file.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
          typename Scalar, typename Index, typename nocontract_t, typename contract_t, int Side,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
  const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
  nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
  // Full and half SIMD packet types for Scalar.
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  // Parent mapper type (the input mapper specialized above).
  typedef TensorContractionInputMapper<
    const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
  // Self type.
  typedef TensorContractionSubMapper<
    const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>

  typedef Self LinearMapper;
  // Evaluator type of the underlying input tensor.
  typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
// Construct a sub-mapper over `base_mapper` at the given vertical (depth)
// and horizontal (patch-column) offsets; precomputes the base indices of the
// anchor patch once so loads stay cheap.
// NOTE(review): the vert_offset/horiz_offset parameter declaration lines
// appear to be missing from this excerpt -- confirm against the full file.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const ParentMapper &base_mapper,
  : m_depth_offset(vert_offset), m_col_offset(horiz_offset), m_base_mapper(base_mapper)
  m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
// Construct a sub-mapper relative to another sub-mapper: offsets accumulate.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const Self &base_mapper,
  : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
    m_col_offset(horiz_offset + base_mapper.m_col_offset),
    m_base_mapper(base_mapper.m_base_mapper)
  m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
// Load a coefficient at depth offset i within the anchor patch.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const
  return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
// Load a coefficient at depth offset i of patch column j.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const
  return m_base_mapper(i + m_depth_offset, j + m_col_offset);

// Packet loads mirroring the scalar accessors above.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const
  return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const
  return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, j + m_col_offset);
// Fast-path variants; caller must guarantee the standard-patch precondition.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i) const
  return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, m_colIndex,

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const
  return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i) const
  // Use the (const-qualified) evaluator type of the parent mapper for SFINAE.
  typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
  return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
    i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
// Alignment can never be guaranteed for patch loads.
template <typename Packet> EIGEN_DEVICE_FUNC bool aligned(Index) const { return false; }

// True when the parent mapper has non-unit input or inflation strides.
EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { return m_base_mapper.nonStandardPatches(); }
// Max(Col|Row|Depth): compute the upper limit for the column, row and depth
// index respectively that fits into the peeled_k elements starting at the
// current depth offset (m_depth_offset).
EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const
  const Index max_col =
    (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / fastPatchColStride();
  return std::min<Index>(1 + max_col, patchCols());

// Upper row limit within patch column `col` for the peeled_k elements.
EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, const Index col) const
  const Index max_row =
    (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - col * patchColStride()) /
    fastPatchRowStride();
  return std::min<Index>(1 + max_row, patchRows());

// Upper depth limit at (col, row) for the peeled_k elements.
EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, Index row) const
  const Index max_depth = m_depth_offset + peeled_k - //
                          col * patchColStride() -    //
                          row * patchRowStride();
  return std::min<Index>(max_depth, patchDepth());

// MaxDepth uses only the remaining number of elements in the peeled_k.
EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, const Index start_depth) const
  return std::min<Index>(start_depth + num_elements, patchDepth());
// Every register matters in this code, so sometimes to prevent register
// spilling, instead of the variable that you would expect to see, we use
// another one, that is guaranteed to have the same value. E.g. patch depth is
// always the same as input depth, and it's also the same as input row stride.
// Bunch of other parameters have similar relations.
typedef internal::TensorIntDivisor<Index> IndexDivisor;

// Patch depth (== input depth == input row stride; see note above).
EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_base_mapper.m_rowInputStride; }
// Patch rows (== the within-patch column stride).
EIGEN_ALWAYS_INLINE Index patchRows() const { return m_base_mapper.m_colStride; }
// Number of columns in a patch.
EIGEN_ALWAYS_INLINE Index patchCols() const { return m_base_mapper.m_patch_cols; }
822 EIGEN_ALWAYS_INLINE Index patchRowStride() const
824 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
825 "Patch depth must be equal to patch row stride.");
829 EIGEN_ALWAYS_INLINE Index patchColStride() const { return m_base_mapper.m_patch_col_stride; }
832 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const
834 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
835 "Patch depth must be equal to patch row stride.");
836 return m_base_mapper.m_fastDimZero; // patch_depth
839 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const
841 return m_base_mapper.m_fastPatchColStride;
// Load a full packet at (depth + baseIndex) straight from the input tensor.
// Callers use this only after proving no padding/skipping is hit.
845 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const
847 const Index inputIndex = depth + baseIndex;
848 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
// Scalar analogue of packetNoPadding().
851 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, const Index baseIndex) const
853 const Index inputIndex = depth + baseIndex;
854 return m_base_mapper.m_impl.coeff(inputIndex);
// Masked partial-packet load of num_coeffs coefficients; SFINAE-enabled only
// when the underlying evaluator supports partialPacket().
856 template <typename PacketT = Packet>
857 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
858 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
859 partialPacketNoPadding(const Index depth, const Index baseIndex, Index num_coeffs) const
861 const Index inputIndex = depth + baseIndex;
862 return m_base_mapper.m_impl.template partialPacket<PacketT>(inputIndex,
863 mask<PacketT>(0, num_coeffs));
// Whether any padding (before or after, in rows or cols) can be hit by this
// mapper — presumably used by callers to select a fast no-padding path
// (confirm at call sites, which are outside this excerpt).
866 EIGEN_ALWAYS_INLINE bool hasPadding() const
868 // TODO(ezhulenev): It does seem that for an inflated filter it's still
869 // possible to guarantee "no padding or skipping" for non-standard packing.
870 if (nonStandardPatches())
873 // Non zero padding before.
874 if (m_base_mapper.m_rowPaddingTop > 0)
876 if (m_base_mapper.m_colPaddingLeft > 0)
879 // Non zero padding after in rows.
880 const Index last_row = (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
881 if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows)
884 // Non zero padding after in cols.
885 const Index last_col = (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
886 if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols)
// True if patch row `row` maps to an out-of-bounds input row (padding).
892 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const
894 const Index r = m_rowIndex + row;
895 return r < 0 || r >= m_base_mapper.m_inputRows;
// True if either end of [first_row, last_row] falls into row padding.
898 EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, const Index last_row) const
900 return m_rowIndex + first_row < 0 || m_rowIndex + last_row >= m_base_mapper.m_inputRows;
// Non-standard-patch variant: writes the original input row to *orig_row and
// returns true when the row is padded OR skipped by the inflate stride.
903 EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row, Index *orig_row) const
905 eigen_assert(nonStandardPatches());
907 const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
908 *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
910 : ((input_row >= 0) ? (input_row / m_base_mapper.m_fastInputRowStride) : 0);
912 return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
913 (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
// Column analogues of the row predicates above.
916 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const
918 const Index c = m_colIndex + col;
919 return c < 0 || c >= m_base_mapper.m_inputCols;
922 EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col, Index *orig_col) const
924 eigen_assert(nonStandardPatches());
926 const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
927 *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
929 : ((input_col >= 0) ? (input_col / m_base_mapper.m_fastInputColStride) : 0);
931 return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
932 (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
// Linear index into the input tensor for patch position (row, col), offset by
// m_otherIndex (batch and any other outer dimensions).
935 EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const
937 const Index r = m_rowIndex + row;
938 const Index c = m_colIndex + col;
939 return r * m_base_mapper.m_rowInputStride + c * m_base_mapper.m_colInputStride + m_otherIndex;
941 // Compute a base index when original input row and column were precomputed
942 // using padOrSkipRow and padOrSkipCol. Used only for non standard patches.
944 EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row, const Index orig_col) const
946 return orig_row * m_base_mapper.m_rowInputStride + orig_col * m_base_mapper.m_colInputStride +
// Convolution strides along the row / column dimensions.
951 EIGEN_ALWAYS_INLINE Index rowStride() const { return m_base_mapper.m_row_strides; }
953 EIGEN_ALWAYS_INLINE Index colStride() const { return m_base_mapper.m_col_strides; }
// Row component of m_depth_offset's position within a patch
// (patchOffset = m_depth_offset / patch_depth, then split into col/row).
956 EIGEN_ALWAYS_INLINE Index rowOffset() const
958 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
959 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
960 return patchOffset - colOffset * m_base_mapper.m_colStride;
// Column component of m_depth_offset's position within a patch.
964 EIGEN_ALWAYS_INLINE Index colOffset() const
966 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
967 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
// Depth component (innermost dimension) of m_depth_offset within a patch.
972 EIGEN_ALWAYS_INLINE Index depthOffset() const { return m_depth_offset % patchDepth(); }
// Sub-mapper shifted by (i, j) relative to this one's depth/column offsets.
974 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const
976 return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
// Offsets of this sub-mapper into the "virtual matrix": the first depth (row)
// and column it exposes to the packing code.
980 Index m_depth_offset; // First row in the input matrix
981 Index m_col_offset; // First col in the input matrix
983 // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
984 // indices for the first element in a patch specified by col_offset
985 // (see computeBaseIndices(...) for details).
990 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
991 // performs better in benchmarks.
994 // Arrange a block of the right input matrix (in our case it's always a "virtual
995 // matrix" constructed from extracted image patches) in contiguous memory.
997 // Given column major input (A0 beside A1 in memory):
998 // A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
999 // A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
1000 // A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
1001 // A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
1002 // A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
1003 // A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
1004 // A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
1005 // A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
1009 // *) A, B, C, ... - patches extracted from the original input.
1010 // *) A0, A1, A2 ... - values from the same patch at different offsets.
1012 // The traversal (packed rhs memory) order (B0 besides A0 in memory):
1013 // A0 B0 C0 D0 A1 B1 C1 D1 ...
1014 // E0 F0 G0 H0 E1 F1 G1 H1 ...
1016 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1018 // This traversal order must be the same as in default gemm_pack_rhs defined in
1019 // GeneralBlockPanelKernel.h.
1021 // *) nr - number of registers along the 'n' dimension.
1022 // See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1023 // Multiplication" paper.
// Packs a block of the rhs "virtual matrix" (extracted image patches) into
// contiguous memory, in the traversal order described in the comment block
// above. Generic specialization for vectorizable packet sizes.
1024 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1025 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1026 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1028 struct gemm_pack_rhs<
1030 TensorContractionSubMapper<
1033 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1034 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1035 nr, ColMajor, false, false>
1037 typedef TensorContractionSubMapper<
1040 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1041 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
1043 typedef SubMapper DataMapper;
1044 typedef typename packet_traits<Scalar>::type Packet;
1046 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
// Pack `depth` x `cols` coefficients of `rhs` into `block`. The default
// stride/offset parameters must stay 0 (asserted below).
1049 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1050 Index stride = 0, Index offset = 0) const
1052 eigen_assert(stride == 0);
1053 eigen_assert(offset == 0);
// Process columns in groups of 4 (nr == 4) and depth in packet-size steps.
1057 const Index packet_cols4 = (cols / 4) * 4;
1058 const Index peeled_k = (depth / packet_size) * packet_size;
1059 const bool non_standard_patches = rhs.nonStandardPatches();
1061 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1063 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1064 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1065 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1066 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1069 if ((packet_size % 4) == 0 && !non_standard_patches)
1072 // Iterate over patch columns and rows, if we know that a single
1073 // packet does not span across multiple rows or columns.
1074 if ((rhs.patchDepth() % packet_size) == 0)
1076 const Index start_col = rhs.colOffset();
1077 const Index max_col = rhs.maxCol(peeled_k);
1079 for (Index c = start_col; c < max_col; ++c)
1081 eigen_assert(k <= peeled_k);
1083 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1084 const Index max_row = rhs.maxRow(peeled_k, c);
1086 const bool pad_col0 = dm0.padCol(c);
1087 const bool pad_col1 = dm1.padCol(c);
1088 const bool pad_col2 = dm2.padCol(c);
1089 const bool pad_col3 = dm3.padCol(c);
1091 // Check if we can squeeze reads along the `row` and `depth`
1092 // dimensions (two innermost dimensions).
1093 if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1094 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1095 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1096 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1097 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1099 // Compute how many elements we can squeeze read.
1100 const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1102 // Upper bound for the number of elements in the depth dimension
1103 // that we can squeeze read.
1104 const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1106 // Do not overshoot beyond the block size.
1107 const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1108 eigen_assert((max_depth - start_depth) % packet_size == 0);
1110 const Index idx0 = dm0.baseIndex(start_row, c);
1111 const Index idx1 = dm1.baseIndex(start_row, c);
1112 const Index idx2 = dm2.baseIndex(start_row, c);
1113 const Index idx3 = dm3.baseIndex(start_row, c);
// Fast path: bulk packet loads with no padding checks.
1115 for (Index d = start_depth; d < max_depth; d += packet_size)
1117 eigen_assert(k < peeled_k);
1118 PacketBlock<Packet, 4> kernel;
1119 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1120 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1121 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1122 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1124 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1125 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1126 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1127 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1128 block += 4 * packet_size;
1132 // Go to the next column.
1136 // If we can't squeeze reads, process rows one by one.
1137 for (Index r = start_row; r < max_row; ++r)
1139 eigen_assert(k <= peeled_k);
1141 const bool pad0 = pad_col0 || dm0.padRow(r);
1142 const bool pad1 = pad_col1 || dm1.padRow(r);
1143 const bool pad2 = pad_col2 || dm2.padRow(r);
1144 const bool pad3 = pad_col3 || dm3.padRow(r);
1146 const Index idx0 = dm0.baseIndex(r, c);
1147 const Index idx1 = dm1.baseIndex(r, c);
1148 const Index idx2 = dm2.baseIndex(r, c);
1149 const Index idx3 = dm3.baseIndex(r, c);
1151 const Index start_depth =
1152 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1153 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1154 eigen_assert((max_depth - start_depth) % packet_size == 0);
// Padded lanes are filled with zeros via pset1.
1156 for (Index d = start_depth; d < max_depth; d += packet_size)
1158 eigen_assert(k < peeled_k);
1159 PacketBlock<Packet, 4> kernel;
1160 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1161 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1162 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1163 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1165 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1166 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1167 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1168 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1169 block += 4 * packet_size;
1175 // The loop above should fill peeled_k elements.
1176 eigen_assert(peeled_k == k);
// Slower "standard" packet path (packet may cross rows/columns).
1180 for (; k < peeled_k; k += packet_size)
1182 PacketBlock<Packet, 4> kernel;
1183 kernel.packet[0] = dm0.loadPacketStandard(k);
1184 kernel.packet[1] = dm1.loadPacketStandard(k);
1185 kernel.packet[2] = dm2.loadPacketStandard(k);
1186 kernel.packet[3] = dm3.loadPacketStandard(k);
1188 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1189 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1190 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1191 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1192 block += 4 * packet_size;
1197 // Copy the remaining coefficients of the column block after the peeled_k.
1198 if (!rhs.nonStandardPatches())
1200 for (; k < depth; k++)
1202 block[0] = dm0.loadCoeffStandard(k);
1203 block[1] = dm1.loadCoeffStandard(k);
1204 block[2] = dm2.loadCoeffStandard(k);
1205 block[3] = dm3.loadCoeffStandard(k);
1211 for (; k < depth; k++)
1222 // Copy the remaining columns one at a time (nr==1).
1223 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1225 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1226 for (Index k = 0; k < depth; k++)
1235 // Template specialization for packet_size = 2. We must special-case packet
1236 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
// Same packing strategy as the generic specialization above, but 4 columns
// are handled as two 2x2 PacketBlocks that are transposed before storing.
1237 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1238 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1239 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1240 struct gemm_pack_rhs<
1242 TensorContractionSubMapper<
1245 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1246 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1247 nr, ColMajor, false, false>
1249 typedef TensorContractionSubMapper<
1252 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1253 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>
1255 typedef SubMapper DataMapper;
1256 typedef typename packet_traits<Scalar>::type Packet;
1258 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
// Pack `depth` x `cols` coefficients of `rhs` into `block` (stride/offset
// must remain 0, asserted below).
1261 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1262 Index stride = 0, Index offset = 0) const
1264 eigen_assert(stride == 0);
1265 eigen_assert(offset == 0);
1270 const int packet_size = 2;
1271 const Index packet_cols4 = (cols / 4) * 4;
1272 const Index peeled_k = (depth / packet_size) * packet_size;
1273 const bool non_standard_patches = rhs.nonStandardPatches();
1275 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1277 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1278 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1279 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1280 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1283 if (!non_standard_patches)
1286 // Iterate over patch columns and rows if we know that a single
1287 // packet does not span across multiple rows or columns.
1288 if ((rhs.patchDepth() % packet_size) == 0)
1290 const Index start_col = rhs.colOffset();
1291 const Index max_col = rhs.maxCol(peeled_k);
1293 for (Index c = start_col; c < max_col; ++c)
1295 eigen_assert(k <= peeled_k);
1297 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1298 const Index max_row = rhs.maxRow(peeled_k, c);
1300 const bool pad_col0 = dm0.padCol(c);
1301 const bool pad_col1 = dm1.padCol(c);
1302 const bool pad_col2 = dm2.padCol(c);
1303 const bool pad_col3 = dm3.padCol(c);
1305 // We can squeeze reads along the `row` and `depth` dimensions if
1306 // the row stride is `1`, which means that `row` and `depth`
1307 // dimensions are contiguous (two innermost dimensions).
1308 if (rhs.rowStride() == 1 && //
1309 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1310 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1311 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1312 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1313 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1315 // Compute how many elements we can squeeze read.
1316 const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1318 // Upper bound for the number of elements in the depth dimension
1319 // that we can squeeze read.
1320 const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1322 // Do not overshoot beyond the block size.
1323 const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1324 eigen_assert((max_depth - start_depth) % packet_size == 0);
1326 const Index idx0 = dm0.baseIndex(start_row, c);
1327 const Index idx1 = dm1.baseIndex(start_row, c);
1328 const Index idx2 = dm2.baseIndex(start_row, c);
1329 const Index idx3 = dm3.baseIndex(start_row, c);
// Fast path: two 2x2 blocks, transposed so column coefficients
// become contiguous in the packed output.
1331 for (Index d = start_depth; d < max_depth; d += packet_size)
1333 PacketBlock<Packet, 2> kernel0;
1334 PacketBlock<Packet, 2> kernel1;
1335 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1336 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1337 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1338 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1339 ptranspose(kernel0);
1340 ptranspose(kernel1);
1341 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1342 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1343 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1344 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1345 block += 4 * packet_size;
1349 // Go to the next column.
1353 // If we can't squeeze reads, process rows one by one.
1354 for (Index r = start_row; r < max_row; ++r)
1356 eigen_assert(k <= peeled_k);
1358 const bool pad0 = pad_col0 || dm0.padRow(r);
1359 const bool pad1 = pad_col1 || dm1.padRow(r);
1360 const bool pad2 = pad_col2 || dm2.padRow(r);
1361 const bool pad3 = pad_col3 || dm3.padRow(r);
1363 const Index idx0 = dm0.baseIndex(r, c);
1364 const Index idx1 = dm1.baseIndex(r, c);
1365 const Index idx2 = dm2.baseIndex(r, c);
1366 const Index idx3 = dm3.baseIndex(r, c);
1368 const Index start_depth =
1369 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1370 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1371 eigen_assert((max_depth - start_depth) % packet_size == 0);
// Padded lanes are filled with zeros via pset1.
1373 for (Index d = start_depth; d < max_depth; d += packet_size)
1375 eigen_assert(k < peeled_k);
1376 PacketBlock<Packet, 2> kernel0;
1377 PacketBlock<Packet, 2> kernel1;
1378 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1379 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1380 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1381 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1382 ptranspose(kernel0);
1383 ptranspose(kernel1);
1384 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1385 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1386 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1387 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1388 block += 4 * packet_size;
1394 // The loop above should fill peeled_k elements.
1395 eigen_assert(peeled_k == k);
1399 // Packet can span multiple rows or columns, so we have to go
1400 // through the slower "standard" path.
1401 for (; k < peeled_k; k += packet_size)
1403 PacketBlock<Packet, 2> kernel0;
1404 PacketBlock<Packet, 2> kernel1;
1405 kernel0.packet[0] = dm0.loadPacketStandard(k);
1406 kernel0.packet[1] = dm1.loadPacketStandard(k);
1407 kernel1.packet[0] = dm2.loadPacketStandard(k);
1408 kernel1.packet[1] = dm3.loadPacketStandard(k);
1409 ptranspose(kernel0);
1410 ptranspose(kernel1);
1411 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1412 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1413 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1414 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1415 block += 4 * packet_size;
1420 // Copy the remaining coefficients of the column block after the peeled_k.
1421 if (!non_standard_patches)
1423 for (; k < depth; k++)
1425 block[0] = dm0.loadCoeffStandard(k);
1426 block[1] = dm1.loadCoeffStandard(k);
1427 block[2] = dm2.loadCoeffStandard(k);
1428 block[3] = dm3.loadCoeffStandard(k);
1434 for (; k < depth; k++)
1445 // Copy the remaining columns one at a time (nr==1).
1446 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1448 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1449 for (Index k = 0; k < depth; k++)
1458 // Special case for non-vectorized types such as float16.
// Scalar-only packing (packet_size == 1): no packet loads or transposes,
// coefficients are copied one at a time.
1459 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1460 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1461 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1462 struct gemm_pack_rhs<
1464 TensorContractionSubMapper<
1467 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1468 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1469 nr, ColMajor, false, false>
1471 typedef TensorContractionSubMapper<
1474 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1475 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
1477 typedef SubMapper DataMapper;
1479 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
// Pack `depth` x `cols` coefficients of `rhs` into `block` (stride/offset
// must remain 0, asserted below).
1482 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1483 Index stride = 0, Index offset = 0) const
1485 eigen_assert(stride == 0);
1486 eigen_assert(offset == 0);
1491 const Index packet_cols4 = (cols / 4) * 4;
// Columns in groups of 4 (nr == 4), copied coefficient by coefficient.
1493 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1495 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1496 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1497 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1498 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1500 if (!rhs.nonStandardPatches())
1502 for (Index k = 0; k < depth; k++)
1504 block[0] = dm0.loadCoeffStandard(k);
1505 block[1] = dm1.loadCoeffStandard(k);
1506 block[2] = dm2.loadCoeffStandard(k);
1507 block[3] = dm3.loadCoeffStandard(k);
1513 for (Index k = 0; k < depth; k++)
1524 // Copy the remaining columns one at a time (nr==1).
1525 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1527 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1528 for (Index k = 0; k < depth; k++)
1536 } // end namespace internal
1538 /** SpatialConvolution
1539 * \ingroup CXX11_NeuralNetworks_Module
1541 * \brief Applies a 2D convolution over a multichannel input image.
1543 * The input parameter is expected to be a tensor with a rank of 3 or more
1544 * (channels, height, width, and optionally others)
1545 * The kernel parameter is expected to be a 4D tensor (filters, channels,
1546 * kernel_height, kernel_width)
1547 * The input and the kernel must both be in col-major layout. The result will
1548 * also be in col-major layout.
1550 * If col_in_stride, row_in_stride > 1, then applies convolution with holes
1551 * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
1554 * If padding_top, padding_bottom, padding_left, or padding_right is specified,
1555 * then those paddings will be used to pad the input, and padding_type must be
1558 * The result can be assigned to a tensor of rank equal to the rank of the
1559 * input. The dimensions of the result will be filters, height, width (and
1560 * others if applicable).
1562 * It is possible to swap the order of the width and height dimensions provided
1563 * that the same order is used in the input, the kernel, and the output.
1565 * It is also possible to add an output kernel to the contraction, output
1566 * kernel is called by Eigen when it "finalizes" the block of an output tensor.
// Expression-building entry point; see the SpatialConvolution doc comment
// above for the full contract. The conditional return type selects between
// the ColMajor and RowMajor expression trees.
1569 template <typename Input, typename Kernel, typename OutputKernel = const NoOpOutputKernel>
1570 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename internal::conditional<
1571 internal::traits<Input>::Layout == ColMajor,
1573 const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>,
1574 const TensorContractionOp<
1575 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1576 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1578 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1579 const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1580 const OutputKernel>>,
1582 const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>,
1583 const TensorContractionOp<
1584 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1585 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1586 const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1587 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1589 const OutputKernel>>>::type
1590 SpatialConvolution(const Input &input, const Kernel &kernel, const Index row_stride = 1,
1591 const Index col_stride = 1, const PaddingType padding_type = PADDING_SAME,
1592 const Index row_in_stride = 1, const Index col_in_stride = 1,
1593 const OutputKernel &output_kernel = OutputKernel(), Index padding_top = 0,
1594 Index padding_bottom = 0, Index padding_left = 0, Index padding_right = 0)
1596 typedef typename internal::traits<Input>::Index TensorIndex;
// TensorRef wrappers let us query dimensions without evaluating the inputs.
1597 TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions,
1598 internal::traits<Input>::Layout, TensorIndex>>
1601 Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions,
1602 internal::traits<Kernel>::Layout, TensorIndex>>
1605 EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1606 YOU_MADE_A_PROGRAMMING_MISTAKE)
1607 const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1609 const int NumDims = internal::traits<Input>::NumDimensions;
1611 // Number of filters to apply. This is the same as the output depth of the
1613 const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
1614 // Number of channels. This is the same as the input depth.
1615 const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
1616 const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
1617 const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
// Effective kernel extent once dilated by the atrous in-strides.
1619 const Index kernelRowsEff = kernelRows + (kernelRows - 1) * (row_in_stride - 1);
1620 const Index kernelColsEff = kernelCols + (kernelCols - 1) * (col_in_stride - 1);
// Contract the patch-value dimension with the matching kernel dimension.
1622 array<IndexPair<TensorIndex>, 1> contract_dims;
1623 contract_dims[0] = IndexPair<TensorIndex>(1, 0);
1625 const TensorIndex InputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1626 const TensorIndex InputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1627 const bool padding_explicit = (padding_top || padding_bottom || padding_left || padding_right);
// Output spatial size depends on the padding mode (VALID vs SAME).
1629 TensorIndex out_height;
1630 TensorIndex out_width;
1631 switch (padding_type)
1635 const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
1636 const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
1637 out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
1638 out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
1643 eigen_assert(!padding_explicit);
1644 out_height = divup(InputRows, row_stride);
1645 out_width = divup(InputCols, col_stride);
1650 // Initialize unused variables to avoid a compiler warning
1653 eigen_assert(false && "unexpected padding");
1657 // Molds the output of the patch extraction code into a 2d tensor:
1658 // - the first dimension (dims[0]): the patch values to be multiplied with the
1660 // - the second dimension (dims[1]): everything else
1661 DSizes<TensorIndex, 2> pre_contract_dims;
1664 pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
1665 pre_contract_dims[1] = out_height * out_width;
1666 for (int i = 3; i < NumDims; ++i)
1668 pre_contract_dims[1] *= in.dimension(i);
1673 pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
1674 pre_contract_dims[0] = out_height * out_width;
1675 for (int i = 0; i < NumDims - 3; ++i)
1677 pre_contract_dims[0] *= in.dimension(i);
1681 // Molds the output of the contraction into the shape expected by the user
1682 // (assuming this is ColMajor):
1683 // - 1st dim: kernel filters
1684 // - 2nd dim: output height
1685 // - 3rd dim: output width
1686 // - 4th dim and beyond: everything else including batch size
1687 DSizes<TensorIndex, NumDims> post_contract_dims;
1690 post_contract_dims[0] = kernelFilters;
1691 post_contract_dims[1] = out_height;
1692 post_contract_dims[2] = out_width;
1693 for (int i = 3; i < NumDims; ++i)
1695 post_contract_dims[i] = in.dimension(i);
1700 post_contract_dims[NumDims - 1] = kernelFilters;
1701 post_contract_dims[NumDims - 2] = out_height;
1702 post_contract_dims[NumDims - 3] = out_width;
1703 for (int i = 0; i < NumDims - 3; ++i)
1705 post_contract_dims[i] = in.dimension(i);
// Flatten the kernel to a 2d matrix: filters x (channels*rows*cols), with
// the two dimensions swapped for RowMajor.
1709 DSizes<TensorIndex, 2> kernel_dims;
1712 kernel_dims[0] = kernelFilters;
1713 kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
1717 kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
1718 kernel_dims[1] = kernelFilters;
// Explicit padding passes the four per-side amounts to the patch extractor;
// otherwise the symbolic padding_type is forwarded.
1720 if (padding_explicit)
1722 return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
1723 kernel.reshape(kernel_dims)
1725 .extract_image_patches(kernelRows, kernelCols, row_stride,
1726 col_stride, row_in_stride, col_in_stride,
1727 /*row_inflate_stride=*/1,
1728 /*col_inflate_stride=*/1, padding_top,
1729 padding_bottom, padding_left, padding_right,
1730 /*padding_value=*/0)
1731 .reshape(pre_contract_dims),
1732 contract_dims, output_kernel)
1733 .reshape(post_contract_dims),
1735 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1736 row_in_stride, col_in_stride,
1737 /*row_inflate_stride=*/1,
1738 /*col_inflate_stride=*/1, padding_top, padding_bottom,
1739 padding_left, padding_right,
1740 /*padding_value=*/0)
1741 .reshape(pre_contract_dims)
1742 .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1743 .reshape(post_contract_dims));
1748 Cond<internal::traits<Input>::Layout == ColMajor>(),
1749 kernel.reshape(kernel_dims)
1751 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1752 row_in_stride, col_in_stride, padding_type)
1753 .reshape(pre_contract_dims),
1754 contract_dims, output_kernel)
1755 .reshape(post_contract_dims),
1757 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
1758 col_in_stride, padding_type)
1759 .reshape(pre_contract_dims)
1760 .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1761 .reshape(post_contract_dims));
1765 } // end namespace Eigen
1767 #endif // __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__